Training a model with JAX MaxText on ROCm#
The MaxText for ROCm training Docker image provides a prebuilt environment for training on AMD Instinct MI355X, MI350X, MI325X, and MI300X GPUs, including essential components like JAX, XLA, ROCm libraries, and MaxText utilities. It includes the following software components:
| Software component | Version |
|---|---|
| ROCm | 7.0.0 |
| JAX | 0.6.2 |
| Python | 3.10.18 |
| Transformer Engine | 2.2.0.dev0+c91bac54 |
| hipBLASLt | 1.x.x |
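As a quick sanity check, you can confirm the shipped versions from inside the container. The following is a minimal sketch; it assumes `python` and `pip` are on the container's PATH and that ROCm is installed under `/opt/rocm`.

```bash
# Verify key component versions inside the container (a sketch; assumes
# python/pip are on PATH and ROCm is installed under /opt/rocm).
python --version
python -c "import jax; print('JAX', jax.__version__)"
pip list 2>/dev/null | grep -i -E 'jax|transformer'
cat /opt/rocm/.info/version   # ROCm version file, if present
```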
MaxText on ROCm provides the following key features for training large language models efficiently:
- Transformer Engine (TE)
- Flash Attention (FA) 3 – with or without sequence input packing
- GEMM tuning
- Multi-node support
- NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X GPUs) quantization support
Supported models#
The following models are pre-optimized for performance on AMD Instinct GPUs. Some instructions, commands, and available training configurations in this documentation might vary by model – select one to get started.
Note
Some models, such as Llama 3, require an external license agreement through a third party (for example, Meta).
System validation#
Before running AI workloads, it’s important to validate that your AMD hardware is configured correctly and performing optimally.
If you have already validated your system settings, including aspects like NUMA auto-balancing, you can skip this step. Otherwise, complete the procedures in the System validation and optimization guide to properly configure your system settings before starting training.
To test for optimal performance, consult the recommended System health benchmarks. This suite of tests will help you verify and fine-tune your system’s configuration.
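Before launching containers, a quick spot check that all GPUs are visible and idle can save a failed run later. The following is a minimal sketch using standard ROCm tooling:

```bash
# Host-side sanity check: confirm all GPUs are visible and idle
# before starting training.
rocm-smi                     # per-GPU utilization, temperature, memory
rocm-smi --showproductname   # confirm the expected Instinct GPU model
```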
Environment setup#
This Docker image is optimized for the specific model configurations outlined below. Performance can vary with other training workloads, as AMD doesn't validate configurations and run conditions outside those described.
Pull the Docker image#
Use the following command to pull the Docker image from Docker Hub.
```bash
docker pull rocm/jax-training:maxtext-v25.9
```
Multi-node configuration#
See Multi-node setup for AI workloads to configure your environment for multi-node training.
Benchmarking#
Once the setup is complete, choose between two options to reproduce the benchmark results:
The following run command is tailored to Llama 2 7B. See Supported models to switch to another available model.
Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.
```bash
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
```
Use this command to run the performance benchmark test on the Llama 2 7B model using one GPU with the `bf16` data type on the host machine.

```bash
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
madengine run \
  --tags jax_maxtext_train_llama-2-7b \
  --keep-model-dir \
  --live-output \
  --timeout 28800
```
MAD launches a Docker container named `container_ci-jax_maxtext_train_llama-2-7b`. The latency and throughput reports of the model are collected at `~/MAD/perf.csv`.
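After the run completes, you can inspect the collected report directly from the shell. A small sketch using standard tools; the exact column names in `perf.csv` depend on the MAD version:

```bash
# Pretty-print the benchmark report (column names vary by MAD version).
column -s, -t < ~/MAD/perf.csv | less -S
```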
The following commands are optimized for Llama 2 7B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.
Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.
```bash
docker pull rocm/jax-training:maxtext-v25.9
```

Single-node training
Set up environment variables.

```bash
export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
export HF_HOME=<Location of saved/cached Hugging Face models>
```

`MAD_SECRETS_HFTOKEN` is your Hugging Face access token used to access models, tokenizers, and data. See User access tokens. `HF_HOME` is where `huggingface_hub` will store local data. See the huggingface_hub CLI documentation. If you have already downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files are typically cached to `~/.cache/huggingface`.

Launch the Docker container.

```bash
docker run -it \
  --device=/dev/dri \
  --device=/dev/kfd \
  --network host \
  --ipc host \
  --group-add video \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --privileged \
  -v $HOME:$HOME \
  -v $HOME/.ssh:/root/.ssh \
  -v $HF_HOME:/hf_cache \
  -e HF_HOME=/hf_cache \
  -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
  --shm-size 64G \
  --name training_env \
  rocm/jax-training:maxtext-v25.9
```
In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at `MAD/scripts/jax-maxtext`.

```bash
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
```
Run the setup scripts to install the libraries and datasets needed for benchmarking.

```bash
./jax-maxtext_benchmark_setup.sh -m Llama-2-7B
```

To run the training benchmark without quantization, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-2-7B
```

For quantized training, run the script with the appropriate option for your Instinct GPU.

For `fp8` quantized training on MI355X and MI350X GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q fp8
```

For `nanoo_fp8` quantized training on MI300X series GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q nanoo_fp8
```
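If you script these runs across mixed GPU fleets, the quantization flag can be chosen from the detected GPU family. The following is a hedged sketch; it assumes the product name reported by `rocm-smi --showproductname` contains the MI-series identifier, which can vary by driver version:

```bash
# Select the -q option for this node's GPU family (a sketch; assumes the
# reported product name contains the MI-series identifier).
GPU_NAME=$(rocm-smi --showproductname 2>/dev/null)
case "$GPU_NAME" in
  *MI355*|*MI350*) QUANT="fp8" ;;        # OCP FP8 on MI350-series GPUs
  *MI325*|*MI300*) QUANT="nanoo_fp8" ;;  # NANOO FP8 on MI300-series GPUs
  *)               QUANT="" ;;           # fall back to bf16
esac
./jax-maxtext_benchmark_report.sh -m Llama-2-7B ${QUANT:+-q $QUANT}
```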
Multi-node training
The following examples use SLURM to run on multiple nodes.
Note
The following scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container.
Make sure `$HF_HOME` is set before running the test. See ROCm benchmarking for more details on downloading the Llama models before running the benchmark.

To run multi-node training for Llama 2 7B, use the multi-node training script under the `scripts/jax-maxtext/gpu-rocm/` directory. Run the multi-node training benchmark script.

```bash
sbatch -N <num_nodes> llama2_7b_multinode.sh
```
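For example, to launch on four nodes and follow the job with standard SLURM commands (this assumes SLURM's default `slurm-<jobid>.out` output file name):

```bash
# Submit on 4 nodes and monitor progress (standard SLURM usage).
sbatch -N 4 llama2_7b_multinode.sh
squeue -u $USER            # confirm the job is queued or running
tail -f slurm-<jobid>.out  # follow training logs once the job starts
```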
The following run command is tailored to Llama 2 70B. See Supported models to switch to another available model.
Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.
```bash
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
```
Use this command to run the performance benchmark test on the Llama 2 70B model using one GPU with the `bf16` data type on the host machine.

```bash
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
madengine run \
  --tags jax_maxtext_train_llama-2-70b \
  --keep-model-dir \
  --live-output \
  --timeout 28800
```
MAD launches a Docker container named `container_ci-jax_maxtext_train_llama-2-70b`. The latency and throughput reports of the model are collected at `~/MAD/perf.csv`.
The following commands are optimized for Llama 2 70B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.
Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.
```bash
docker pull rocm/jax-training:maxtext-v25.9
```

Single-node training
Set up environment variables.

```bash
export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
export HF_HOME=<Location of saved/cached Hugging Face models>
```

`MAD_SECRETS_HFTOKEN` is your Hugging Face access token used to access models, tokenizers, and data. See User access tokens. `HF_HOME` is where `huggingface_hub` will store local data. See the huggingface_hub CLI documentation. If you have already downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files are typically cached to `~/.cache/huggingface`.

Launch the Docker container.

```bash
docker run -it \
  --device=/dev/dri \
  --device=/dev/kfd \
  --network host \
  --ipc host \
  --group-add video \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --privileged \
  -v $HOME:$HOME \
  -v $HOME/.ssh:/root/.ssh \
  -v $HF_HOME:/hf_cache \
  -e HF_HOME=/hf_cache \
  -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
  --shm-size 64G \
  --name training_env \
  rocm/jax-training:maxtext-v25.9
```
In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at `MAD/scripts/jax-maxtext`.

```bash
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
```
Run the setup scripts to install the libraries and datasets needed for benchmarking.

```bash
./jax-maxtext_benchmark_setup.sh -m Llama-2-70B
```

To run the training benchmark without quantization, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-2-70B
```

For quantized training, run the script with the appropriate option for your Instinct GPU.

For `fp8` quantized training on MI355X and MI350X GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-2-70B -q fp8
```

For `nanoo_fp8` quantized training on MI300X series GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-2-70B -q nanoo_fp8
```
Multi-node training
The following examples use SLURM to run on multiple nodes.
Note
The following scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container.
Make sure `$HF_HOME` is set before running the test. See ROCm benchmarking for more details on downloading the Llama models before running the benchmark.

To run multi-node training for Llama 2 70B, use the multi-node training script under the `scripts/jax-maxtext/gpu-rocm/` directory. Run the multi-node training benchmark script.

```bash
sbatch -N <num_nodes> llama2_70b_multinode.sh
```
The following commands are optimized for Llama 3 8B (multi-node). See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.
Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.
```bash
docker pull rocm/jax-training:maxtext-v25.9
```
Multi-node training
The following examples use SLURM to run on multiple nodes.
Note
The following scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container.
Make sure `$HF_HOME` is set before running the test. See ROCm benchmarking for more details on downloading the Llama models before running the benchmark.

To run multi-node training for Llama 3 8B, use the multi-node training script under the `scripts/jax-maxtext/gpu-rocm/` directory. Run the multi-node training benchmark script.

```bash
sbatch -N <num_nodes> llama3_8b_multinode.sh
```
The following commands are optimized for Llama 3 70B (multi-node). See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.
Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.
```bash
docker pull rocm/jax-training:maxtext-v25.9
```
Multi-node training
The following examples use SLURM to run on multiple nodes.
Note
The following scripts will launch the Docker container and run the benchmark. Run them outside of any Docker container.
Make sure `$HF_HOME` is set before running the test. See ROCm benchmarking for more details on downloading the Llama models before running the benchmark.

To run multi-node training for Llama 3 70B, use the multi-node training script under the `scripts/jax-maxtext/gpu-rocm/` directory. Run the multi-node training benchmark script.

```bash
sbatch -N <num_nodes> llama3_70b_multinode.sh
```
The following run command is tailored to Llama 3.1 8B. See Supported models to switch to another available model.
Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.
```bash
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
```
Use this command to run the performance benchmark test on the Llama 3.1 8B model using one GPU with the `bf16` data type on the host machine.

```bash
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
madengine run \
  --tags jax_maxtext_train_llama-3.1-8b \
  --keep-model-dir \
  --live-output \
  --timeout 28800
```
MAD launches a Docker container named `container_ci-jax_maxtext_train_llama-3.1-8b`. The latency and throughput reports of the model are collected at `~/MAD/perf.csv`.
The following commands are optimized for Llama 3.1 8B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.
Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.
```bash
docker pull rocm/jax-training:maxtext-v25.9
```

Single-node training
Set up environment variables.

```bash
export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
export HF_HOME=<Location of saved/cached Hugging Face models>
```

`MAD_SECRETS_HFTOKEN` is your Hugging Face access token used to access models, tokenizers, and data. See User access tokens. `HF_HOME` is where `huggingface_hub` will store local data. See the huggingface_hub CLI documentation. If you have already downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files are typically cached to `~/.cache/huggingface`.

Launch the Docker container.

```bash
docker run -it \
  --device=/dev/dri \
  --device=/dev/kfd \
  --network host \
  --ipc host \
  --group-add video \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --privileged \
  -v $HOME:$HOME \
  -v $HOME/.ssh:/root/.ssh \
  -v $HF_HOME:/hf_cache \
  -e HF_HOME=/hf_cache \
  -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
  --shm-size 64G \
  --name training_env \
  rocm/jax-training:maxtext-v25.9
```
In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at `MAD/scripts/jax-maxtext`.

```bash
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
```
Run the setup scripts to install the libraries and datasets needed for benchmarking.

```bash
./jax-maxtext_benchmark_setup.sh -m Llama-3.1-8B
```

To run the training benchmark without quantization, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-3.1-8B
```

For quantized training, run the script with the appropriate option for your Instinct GPU.

For `fp8` quantized training on MI355X and MI350X GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-3.1-8B -q fp8
```

For `nanoo_fp8` quantized training on MI300X series GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-3.1-8B -q nanoo_fp8
```
Multi-node training
For multi-node training examples, choose a model from Supported models with an available multi-node training script.
The following run command is tailored to Llama 3.1 70B. See Supported models to switch to another available model.
Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.
```bash
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
```
Use this command to run the performance benchmark test on the Llama 3.1 70B model using one GPU with the `bf16` data type on the host machine.

```bash
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
madengine run \
  --tags jax_maxtext_train_llama-3.1-70b \
  --keep-model-dir \
  --live-output \
  --timeout 28800
```
MAD launches a Docker container named `container_ci-jax_maxtext_train_llama-3.1-70b`. The latency and throughput reports of the model are collected at `~/MAD/perf.csv`.
The following commands are optimized for Llama 3.1 70B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.
Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.
```bash
docker pull rocm/jax-training:maxtext-v25.9
```

Single-node training
Set up environment variables.

```bash
export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
export HF_HOME=<Location of saved/cached Hugging Face models>
```

`MAD_SECRETS_HFTOKEN` is your Hugging Face access token used to access models, tokenizers, and data. See User access tokens. `HF_HOME` is where `huggingface_hub` will store local data. See the huggingface_hub CLI documentation. If you have already downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files are typically cached to `~/.cache/huggingface`.

Launch the Docker container.

```bash
docker run -it \
  --device=/dev/dri \
  --device=/dev/kfd \
  --network host \
  --ipc host \
  --group-add video \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --privileged \
  -v $HOME:$HOME \
  -v $HOME/.ssh:/root/.ssh \
  -v $HF_HOME:/hf_cache \
  -e HF_HOME=/hf_cache \
  -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
  --shm-size 64G \
  --name training_env \
  rocm/jax-training:maxtext-v25.9
```
In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at `MAD/scripts/jax-maxtext`.

```bash
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
```
Run the setup scripts to install the libraries and datasets needed for benchmarking.

```bash
./jax-maxtext_benchmark_setup.sh -m Llama-3.1-70B
```

To run the training benchmark without quantization, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-3.1-70B
```

For quantized training, run the script with the appropriate option for your Instinct GPU.

For `fp8` quantized training on MI355X and MI350X GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-3.1-70B -q fp8
```
Multi-node training
For multi-node training examples, choose a model from Supported models with an available multi-node training script.
The following run command is tailored to Llama 3.3 70B. See Supported models to switch to another available model.
Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.
```bash
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
```
Use this command to run the performance benchmark test on the Llama 3.3 70B model using one GPU with the `bf16` data type on the host machine.

```bash
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
madengine run \
  --tags jax_maxtext_train_llama-3.3-70b \
  --keep-model-dir \
  --live-output \
  --timeout 28800
```
MAD launches a Docker container named `container_ci-jax_maxtext_train_llama-3.3-70b`. The latency and throughput reports of the model are collected at `~/MAD/perf.csv`.
The following commands are optimized for Llama 3.3 70B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.
Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.
```bash
docker pull rocm/jax-training:maxtext-v25.9
```

Single-node training
Set up environment variables.

```bash
export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
export HF_HOME=<Location of saved/cached Hugging Face models>
```

`MAD_SECRETS_HFTOKEN` is your Hugging Face access token used to access models, tokenizers, and data. See User access tokens. `HF_HOME` is where `huggingface_hub` will store local data. See the huggingface_hub CLI documentation. If you have already downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files are typically cached to `~/.cache/huggingface`.

Launch the Docker container.

```bash
docker run -it \
  --device=/dev/dri \
  --device=/dev/kfd \
  --network host \
  --ipc host \
  --group-add video \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --privileged \
  -v $HOME:$HOME \
  -v $HOME/.ssh:/root/.ssh \
  -v $HF_HOME:/hf_cache \
  -e HF_HOME=/hf_cache \
  -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
  --shm-size 64G \
  --name training_env \
  rocm/jax-training:maxtext-v25.9
```
In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at `MAD/scripts/jax-maxtext`.

```bash
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
```
Run the setup scripts to install the libraries and datasets needed for benchmarking.

```bash
./jax-maxtext_benchmark_setup.sh -m Llama-3.3-70B
```

To run the training benchmark without quantization, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-3.3-70B
```

For quantized training, run the script with the appropriate option for your Instinct GPU.

For `fp8` quantized training on MI355X and MI350X GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Llama-3.3-70B -q fp8
```
Multi-node training
For multi-node training examples, choose a model from Supported models with an available multi-node training script.
The following run command is tailored to DeepSeek-V2-Lite (16B). See Supported models to switch to another available model.
Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.
```bash
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
```
Use this command to run the performance benchmark test on the DeepSeek-V2-Lite (16B) model using one GPU with the `bf16` data type on the host machine.

```bash
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
madengine run \
  --tags jax_maxtext_train_deepseek-v2-lite-16b \
  --keep-model-dir \
  --live-output \
  --timeout 28800
```
MAD launches a Docker container named `container_ci-jax_maxtext_train_deepseek-v2-lite-16b`. The latency and throughput reports of the model are collected at `~/MAD/perf.csv`.
The following commands are optimized for DeepSeek-V2-Lite (16B). See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.
Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.
```bash
docker pull rocm/jax-training:maxtext-v25.9
```

Single-node training
Set up environment variables.

```bash
export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
export HF_HOME=<Location of saved/cached Hugging Face models>
```

`MAD_SECRETS_HFTOKEN` is your Hugging Face access token used to access models, tokenizers, and data. See User access tokens. `HF_HOME` is where `huggingface_hub` will store local data. See the huggingface_hub CLI documentation. If you have already downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files are typically cached to `~/.cache/huggingface`.

Launch the Docker container.

```bash
docker run -it \
  --device=/dev/dri \
  --device=/dev/kfd \
  --network host \
  --ipc host \
  --group-add video \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --privileged \
  -v $HOME:$HOME \
  -v $HOME/.ssh:/root/.ssh \
  -v $HF_HOME:/hf_cache \
  -e HF_HOME=/hf_cache \
  -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
  --shm-size 64G \
  --name training_env \
  rocm/jax-training:maxtext-v25.9
```
In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at `MAD/scripts/jax-maxtext`.

```bash
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
```
Run the setup scripts to install the libraries and datasets needed for benchmarking.

```bash
./jax-maxtext_benchmark_setup.sh -m DeepSeek-V2-lite
```

To run the training benchmark without quantization, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m DeepSeek-V2-lite
```

For quantized training, run the script with the appropriate option for your Instinct GPU.

For `fp8` quantized training on MI355X and MI350X GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m DeepSeek-V2-lite -q fp8
```

For `nanoo_fp8` quantized training on MI300X series GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m DeepSeek-V2-lite -q nanoo_fp8
```
Multi-node training
For multi-node training examples, choose a model from Supported models with an available multi-node training script.
The following run command is tailored to Mixtral 8x7B. See Supported models to switch to another available model.
Clone the ROCm Model Automation and Dashboarding (ROCm/MAD) repository to a local directory and install the required packages on the host machine.
```bash
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
```
Use this command to run the performance benchmark test on the Mixtral 8x7B model using one GPU with the `bf16` data type on the host machine.

```bash
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
madengine run \
  --tags jax_maxtext_train_mixtral-8x7b \
  --keep-model-dir \
  --live-output \
  --timeout 28800
```
MAD launches a Docker container named `container_ci-jax_maxtext_train_mixtral-8x7b`. The latency and throughput reports of the model are collected at `~/MAD/perf.csv`.
The following commands are optimized for Mixtral 8x7B. See Supported models to switch to another available model. Some instructions and resources might not be available for all models and configurations.
Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the Docker container as shown in the following snippet.
```bash
docker pull rocm/jax-training:maxtext-v25.9
```

Single-node training
Set up environment variables.

```bash
export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
export HF_HOME=<Location of saved/cached Hugging Face models>
```

`MAD_SECRETS_HFTOKEN` is your Hugging Face access token used to access models, tokenizers, and data. See User access tokens. `HF_HOME` is where `huggingface_hub` will store local data. See the huggingface_hub CLI documentation. If you have already downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files are typically cached to `~/.cache/huggingface`.

Launch the Docker container.

```bash
docker run -it \
  --device=/dev/dri \
  --device=/dev/kfd \
  --network host \
  --ipc host \
  --group-add video \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --privileged \
  -v $HOME:$HOME \
  -v $HOME/.ssh:/root/.ssh \
  -v $HF_HOME:/hf_cache \
  -e HF_HOME=/hf_cache \
  -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
  --shm-size 64G \
  --name training_env \
  rocm/jax-training:maxtext-v25.9
```
In the Docker container, clone the ROCm MAD repository and navigate to the benchmark scripts directory at `MAD/scripts/jax-maxtext`.

```bash
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
```
Run the setup scripts to install the libraries and datasets needed for benchmarking.

```bash
./jax-maxtext_benchmark_setup.sh -m Mixtral-8x7B
```

To run the training benchmark without quantization, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Mixtral-8x7B
```

For quantized training, run the script with the appropriate option for your Instinct GPU.

For `fp8` quantized training on MI355X and MI350X GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Mixtral-8x7B -q fp8
```

For `nanoo_fp8` quantized training on MI300X series GPUs, use the following command:

```bash
./jax-maxtext_benchmark_report.sh -m Mixtral-8x7B -q nanoo_fp8
```
Multi-node training
For multi-node training examples, choose a model from Supported models with an available multi-node training script.
Further reading#
To learn more about MAD and the `madengine` CLI, see the MAD usage guide.

To learn more about system settings and management practices to configure your system for AMD Instinct MI300X series GPUs, see AMD Instinct MI300X system optimization.
For a list of other ready-made Docker images for AI with ROCm, see AMD Infinity Hub.
Previous versions#
See JAX MaxText training performance testing version history to find documentation for previous releases
of the ROCm/jax-training Docker image.