Multi-GPU JAX Training with AMD rocAL

This notebook demonstrates how to build an efficient, GPU-accelerated data loading pipeline for the CIFAR-10 dataset using AMD’s ROCm Augmentation Library (rocAL) and integrate it with JAX for multi-GPU model training.

We will explore two common JAX parallelization strategies:

  1. Automatic Parallelization (Sharding): JAX automatically distributes data and computation across multiple devices with minimal code changes.

  2. pmap (Parallel Map): An explicit data parallelism approach where the user has more direct control over how data and models are replicated and computations are performed across devices.

The ROCALJaxIterator from the rocAL JAX plugin (amd.rocal.plugin.jax) connects the two, delivering the augmented output of rocAL pipelines to JAX.

1. JAX Multi-GPU Setup

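First, we set up the JAX environment for multi-GPU computation. We create a mesh, which is a logical grid of devices. Here mesh_utils.create_device_mesh arranges all available GPUs into a (device_count, 1) grid: one row per GPU and a single column. PositionalSharding wraps this grid and tells JAX how to lay arrays out across it. For a 2-D array, the first (batch) axis is split across the GPUs and the second axis is kept whole. Recent JAX releases have deprecated PositionalSharding in favor of NamedSharding, so this code may require an older JAX version.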

import jax
from jax.sharding import PositionalSharding
from jax.experimental import mesh_utils
from jax import jit, grad, lax, pmap
from functools import partial
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import numpy as np
import numpy.random as npr

from amd.rocal.pipeline import Pipeline
import amd.rocal.fn as fn
import amd.rocal.types as types
from amd.rocal.plugin.jax import ROCALJaxIterator

# Create a JAX device mesh for multi-GPU training
mesh = mesh_utils.create_device_mesh((jax.device_count(), 1))
sharding = PositionalSharding(mesh)
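To make the (device_count, 1) layout concrete, the NumPy sketch below mimics how such a sharding splits a 2-D batch. Axis 0 (samples) is divided into equal contiguous blocks, one per device, and axis 1 (features) stays whole. It is only an illustration of the layout and does not call JAX. The helper split_rows_like_sharding is hypothetical.

```python
import numpy as np

def split_rows_like_sharding(global_batch: np.ndarray, num_devices: int) -> list[np.ndarray]:
    """Hypothetical helper: split axis 0 into equal contiguous blocks, one per device,
    mirroring a sharding over a (num_devices, 1) device mesh. Axis 1 is not split."""
    if global_batch.shape[0] % num_devices != 0:
        raise ValueError("batch dimension must be divisible by the number of devices")
    rows_per_device = global_batch.shape[0] // num_devices
    return [global_batch[i * rows_per_device:(i + 1) * rows_per_device]
            for i in range(num_devices)]

# Example: 4 simulated devices and a global batch of 64 flattened CIFAR-10 images.
batch = np.arange(64 * 3072, dtype=np.float32).reshape(64, 3072)
shards = split_rows_like_sharding(batch, num_devices=4)
print([s.shape for s in shards])  # [(16, 3072), (16, 3072), (16, 3072), (16, 3072)]
```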

2. Data Loading with rocAL for CIFAR-10

Next, we define our data loading and augmentation pipelines using rocAL. For multi-GPU training, we create a separate pipeline for each GPU. Each pipeline is responsible for reading and processing a unique shard of the dataset, which is a common and efficient strategy for distributed training.

Training Pipeline

We instantiate one Pipeline for each available device (jax.device_count()). Key parameters:

  • device_id: Assigns the pipeline to a specific GPU.

  • shard_id & num_shards: These parameters tell the cifar10 reader to process a specific fraction of the dataset, ensuring each GPU sees unique data.

  • rocal_cpu=False: Runs the rocAL pipeline on the GPU instead of the CPU. Combined with output_memory_type=types.DEVICE_MEMORY, this keeps the output batches in GPU memory.
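The role of shard_id and num_shards can be pictured with the sketch below. It assumes the simplest scheme: contiguous, nearly equal ranges of sample indices. rocAL's actual assignment, including how random_shuffle interacts with it, may differ. The shard_range helper is hypothetical.

```python
def shard_range(num_samples: int, shard_id: int, num_shards: int) -> tuple[int, int]:
    """Hypothetical helper: [start, end) indices of the samples owned by one shard,
    assuming contiguous shards whose sizes differ by at most one sample."""
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    start = num_samples * shard_id // num_shards
    end = num_samples * (shard_id + 1) // num_shards
    return start, end

# CIFAR-10 has 50,000 training images; split them across 8 GPUs.
ranges = [shard_range(50_000, i, 8) for i in range(8)]
print(ranges[0], ranges[-1])  # (0, 6250) (43750, 50000)
```

Because consecutive ranges share their boundaries, no sample is read by two GPUs and none is skipped.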

# Training and validation parameters
batch_size = 16
num_epochs = 5
data_path = './data' # NOTE: Update this to the directory that contains the CIFAR-10 binary folder 'cifar-10-batches-bin'
device_count = jax.device_count()


# Create the training pipelines (one per device)
train_pipelines = []
for device_id in range(device_count):
    train_pipeline = Pipeline(batch_size=batch_size, num_threads=8, device_id=device_id, seed=device_id + 42,
                              rocal_cpu=False, tensor_dtype=types.FLOAT, tensor_layout=types.NCHW,
                              prefetch_queue_depth=3,
                              mean=[0.5 * 255, 0.5 * 255, 0.5 * 255], std=[0.5 * 255, 0.5 * 255, 0.5 * 255],
                              output_memory_type=types.DEVICE_MEMORY)

    with train_pipeline:
        # Each device reads its own shard of the training files (data_batch_1.bin ... data_batch_5.bin)
        cifar10_reader_output = fn.readers.cifar10(file_root=f'{data_path}/cifar-10-batches-bin',
                                                   shard_id=device_id, num_shards=device_count,
                                                   filename_prefix='data_batch_', random_shuffle=True,
                                                   last_batch_policy=types.LAST_BATCH_DROP)
        cmnp = fn.crop_mirror_normalize(cifar10_reader_output,
                                        output_layout=types.NCHW,
                                        output_dtype=types.FLOAT,
                                        crop=(32, 32),
                                        mirror=0,  # no horizontal flip
                                        mean=[0.5 * 255, 0.5 * 255, 0.5 * 255],
                                        std=[0.5 * 255, 0.5 * 255, 0.5 * 255])
        train_pipeline.set_outputs(cmnp)

    train_pipeline.build()
    train_pipelines.append(train_pipeline)
    print(f'Training pipeline created for device {device_id}')


training_iterator = ROCALJaxIterator(train_pipelines, sharding)
print(f"Number of batches in training iterator = {len(training_iterator)}")
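The mean and std values passed to crop_mirror_normalize are 0.5 * 255 for every channel. Assuming the usual per-channel formula (x - mean) / std, these values map 8-bit pixel values from [0, 255] to [-1, 1]. The NumPy lines below reproduce that arithmetic on three sample pixels. They illustrate the formula and are not rocAL code.

```python
import numpy as np

mean = np.array([0.5 * 255] * 3, dtype=np.float32)  # per-channel mean, as in the pipeline
std = np.array([0.5 * 255] * 3, dtype=np.float32)   # per-channel std, as in the pipeline

# One RGB pixel per row: black, mid-gray, white.
pixels = np.array([[0, 0, 0], [127.5, 127.5, 127.5], [255, 255, 255]], dtype=np.float32)
normalized = (pixels - mean) / std
print(normalized[:, 0])  # [-1.  0.  1.]
```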

Validation Pipeline

For validation, we typically run inference on a single GPU. Therefore, we only need to create one pipeline that reads the entire test dataset (num_shards=1). We initialize the corresponding ROCALJaxIterator without the sharding argument, as automatic distribution is not needed.

# Create the validation pipeline (single device)
val_pipeline = Pipeline(batch_size=batch_size, num_threads=8, device_id=0, seed=42,
                        rocal_cpu=False, tensor_dtype=types.FLOAT, tensor_layout=types.NCHW,
                        prefetch_queue_depth=3,
                        mean=[0.5 * 255, 0.5 * 255, 0.5 * 255], std=[0.5 * 255, 0.5 * 255, 0.5 * 255],
                        output_memory_type=types.DEVICE_MEMORY)

with val_pipeline:
    val_cifar10_reader_output = fn.readers.cifar10(file_root=f'{data_path}/cifar-10-batches-bin',
                                                   shard_id=0, num_shards=1,
                                                   filename_prefix='test_batch',
                                                   last_batch_policy=types.LAST_BATCH_DROP)
    val_cmnp = fn.crop_mirror_normalize(val_cifar10_reader_output,
                                        output_layout=types.NCHW,
                                        output_dtype=types.FLOAT,
                                        crop=(32, 32),
                                        mirror=0,
                                        mean=[0.5 * 255, 0.5 * 255, 0.5 * 255],
                                        std=[0.5 * 255, 0.5 * 255, 0.5 * 255])
    val_pipeline.set_outputs(val_cmnp)

val_pipeline.build()

validation_iterator = ROCALJaxIterator(val_pipeline)
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
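The validation pipeline reads the full CIFAR-10 test set of 10,000 images and drops any incomplete final batch (LAST_BATCH_DROP). The number of validation batches therefore follows from integer division. The helper below is a hypothetical sketch of that arithmetic. The value printed by len(validation_iterator) is what the pipeline actually reports.

```python
import math

def num_batches(num_samples: int, batch_size: int, drop_last: bool = True) -> int:
    """Hypothetical helper: batches per epoch when the last partial batch is dropped or kept."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

print(num_batches(10_000, 16))                   # 625: nothing is dropped
print(num_batches(10_000, 24))                   # 416: the last 16 images are skipped
print(num_batches(10_000, 24, drop_last=False))  # 417
```

With the notebook's batch_size of 16, the test set divides evenly, so the drop policy discards no validation images.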

3. Model Definition and Utility Functions

We define a simple Multi-Layer Perceptron (MLP) for classifying the CIFAR-10 images. We also create standard utility functions for:

  • init_model: Initializing model parameters (weights and biases).

  • predict: The forward pass of the model.

  • loss: Computes the cross-entropy loss.

  • accuracy: Evaluates model performance on the validation set.
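With layers = [3072, 1024, 1024, 10], each (w, b) pair contributes in_size * out_size weights plus out_size biases. The short computation below totals them, using the same zip over consecutive layer sizes that init_model uses. It is a standalone illustration of the model's size.

```python
layers = [32 * 32 * 3, 1024, 1024, 10]

# Each layer has an (in_size, out_size) weight matrix and an out_size bias vector.
params_per_layer = [in_size * out_size + out_size
                    for in_size, out_size in zip(layers[:-1], layers[1:])]
print(params_per_layer)       # [3146752, 1049600, 10250]
print(sum(params_per_layer))  # 4206602
```

Roughly three quarters of the parameters sit in the first layer, which maps the 3072 input pixels to 1024 hidden units.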

# Layer sizes for CIFAR-10: each 3x32x32 image is flattened into a 3072-element vector
image_vector_size = 32 * 32 * 3
num_classes = 10
layers = [image_vector_size, 1024, 1024, num_classes]

def init_model(layers=layers, seed=0):
    """Initializes the model's weights and biases.

    A fresh RandomState is created on every call, so both training runs
    below start from identical parameters.
    """
    rng = npr.RandomState(seed)
    model = []
    for in_size, out_size in zip(layers[:-1], layers[1:]):
        w = 0.1 * rng.randn(in_size, out_size)
        b = 0.1 * rng.randn(out_size)
        model.append((w, b))
    return model

def predict(model, images):
    """Computes log-probabilities (log-softmax of the logits) for each image."""
    activations = images
    for w, b in model[:-1]:
        outputs = jnp.dot(activations, w) + b
        activations = jnp.tanh(outputs)
    last_w, last_b = model[-1]
    logits = jnp.dot(activations, last_w) + last_b
    return logits - logsumexp(logits, axis=1, keepdims=True)

def loss(model, batch):
    """Computes the mean cross-entropy loss; integer labels are one-hot encoded here."""
    images, int_labels = batch["images"], batch["labels"]
    one_hot_labels = jax.nn.one_hot(int_labels.ravel(), num_classes=num_classes)
    log_probs = predict(model, images)
    return -jnp.mean(jnp.sum(log_probs * one_hot_labels, axis=1))

def accuracy(model, iterator, device):
    """Calculates the model's accuracy on the validation set, evaluating on `device`."""
    model = jax.device_put(model, device)
    correct_predictions = 0
    total_samples = 0
    for images, labels in iterator:
        images_flattened = images.reshape(images.shape[0], -1)
        predicted_class = jnp.argmax(predict(model, images_flattened), axis=1)
        correct_predictions += jnp.sum(predicted_class == labels.ravel())
        total_samples += len(labels)
    return correct_predictions / total_samples
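predict returns log-probabilities, because subtracting logsumexp from the logits is a numerically stable log-softmax. loss then reduces to the negative mean log-probability of the true class. The NumPy sketch below checks this on a case with a known answer. All-zero logits give a uniform distribution over 10 classes, so the loss is log(10), about 2.3026. The functions log_softmax and cross_entropy are hypothetical NumPy stand-ins for the JAX code above.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Stable log-softmax: logits - logsumexp(logits) along the class axis."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # max shift avoids overflow in exp
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def cross_entropy(log_probs: np.ndarray, int_labels: np.ndarray, num_classes: int = 10) -> float:
    """Negative mean log-probability of the true class, selected via one-hot labels."""
    one_hot = np.eye(num_classes)[int_labels]
    return float(-np.mean(np.sum(log_probs * one_hot, axis=1)))

logits = np.zeros((4, 10))  # uniform predictions for 4 samples
labels = np.array([0, 3, 7, 9])
print(round(cross_entropy(log_softmax(logits), labels), 4))  # 2.3026

# Large logits stay finite thanks to the max shift.
print(np.isfinite(log_softmax(np.array([[1000.0, 0.0]]))).all())  # True
```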

Training Step Functions

We define two versions of the update function to demonstrate both parallelization approaches.

  1. update: A standard function decorated with @jit. When the input batch is sharded across devices, JAX's compiler automatically partitions the computation. It also inserts the cross-device communication needed to combine the gradients. The model parameters, which are not sharded, are made available to every device.

  2. update_parallel: This function is decorated with @pmap, which explicitly runs the same function on every device, each with its own slice of the data. Inside it, lax.pmean is crucial. It averages the gradients over all devices along the "data" axis and returns the averaged result to each device. Every replica therefore applies the same update, and the model weights stay synchronized.

@jit
def update(model, batch, learning_rate=0.001):
    """JIT-compiled update step for a single device or automatic sharding."""
    grads = grad(loss)(model, batch)
    updated_model = []
    for (w, b), (dw, db) in zip(model, grads):
        new_w = w - learning_rate * dw
        new_b = b - learning_rate * db
        updated_model.append((new_w, new_b))
    return updated_model

@partial(pmap, axis_name="data")
def update_parallel(model, batch, learning_rate=0.001):
    """pmapped update step with synchronized gradients for multi-GPU."""
    grads = grad(loss)(model, batch)
    grads = lax.pmean(grads, axis_name="data")
    updated_model = []
    for (w, b), (dw, db) in zip(model, grads):
        new_w = w - learning_rate * dw
        new_b = b - learning_rate * db
        updated_model.append((new_w, new_b))
    return updated_model
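The reason averaging with lax.pmean keeps training correct can be checked numerically. When the per-device batches are equally sized and the loss is a per-sample mean, the average of the per-device gradients equals the gradient on the whole global batch. The NumPy sketch below verifies this for a hypothetical linear least-squares loss. It simulates four devices by splitting one batch. np.mean over the list of local gradients plays the role of lax.pmean.

```python
import numpy as np

rng = np.random.default_rng(0)
num_devices, per_device, features = 4, 8, 5
x = rng.normal(size=(num_devices * per_device, features))
y = rng.normal(size=num_devices * per_device)
w = rng.normal(size=features)

def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(y)

# Each simulated device computes a gradient on its own equally sized slice of the batch.
local_grads = [mse_grad(w, xs, ys)
               for xs, ys in zip(np.split(x, num_devices), np.split(y, num_devices))]

# What lax.pmean hands back to every device: the mean of the local gradients.
averaged = np.mean(local_grads, axis=0)

print(np.allclose(averaged, mse_grad(w, x, y)))  # True
```

If the slices had different sizes, a plain mean would weight small slices too heavily. The equal-size requirement in the pmap section avoids that.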

4. Training with Automatic Parallelization (Sharding)

This approach is often simpler to implement. The training ROCALJaxIterator was initialized with our sharding configuration. It therefore returns a global array whose batch is already split across the devices and placed on each of them. Older JAX releases call this a GlobalDeviceArray; current JAX calls it a sharded jax.Array.

The training loop looks almost identical to single-GPU code. JAX handles the distributed computation behind the scenes when we call our JIT-compiled update function. For validation, we copy the model parameters onto a single device using jax.device_put.

print("\n--- Starting Training with Automatic Parallelization (sharding) ---")

model_sharding = init_model()


for epoch in range(num_epochs):
    for it, (images, labels) in enumerate(training_iterator):
        images_flattened = images.reshape(images.shape[0], -1)
        batch = {
            "images": images_flattened,
            "labels": labels.reshape(-1, 1),
        }
        model_sharding = update(model_sharding, batch)

    validation_device = jax.devices()[0]
    model_on_one_device = jax.device_put(model_sharding, validation_device)
    test_acc = accuracy(model_on_one_device, validation_iterator, validation_device)

    print(f"Epoch {epoch}")
    print(f"Test set accuracy with sharding: {test_acc:.4f}")
    training_iterator.reset()
    validation_iterator.reset()

5. Training with pmap

pmap provides a more explicit way to handle data parallelism.

Key steps:

  1. Replicate Model: The model parameters must be explicitly replicated on every device using jax.device_put_replicated.

  2. Reshape Data: The ROCALJaxIterator provides a single, global batch of data containing device_count * batch_size samples. For pmap, we must manually reshape this batch so that its first dimension corresponds to the number of devices. For example, with batch_size = 64 per pipeline, a global batch of shape (256, 3072) on 4 GPUs becomes (4, 64, 3072).

  3. Call pmap function: We call update_parallel with the replicated model and the reshaped data batch.

  4. Gather Model for Validation: The updated model model_pmap is a replicated PyTree whose leaves carry a leading device axis. To run validation, we take the copy from the first device with jax.tree.map(lambda x: x[0], model_pmap).

print("\n--- Starting Training with pmap ---")

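# Each rocAL pipeline yields `batch_size` samples per step, so the iterator's global batch
# always holds device_count * batch_size samples: one `batch_size` slice per device.
# batch_size is a per-device size here, so it does not need to be divisible by device_count.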

model_cpu = init_model()
# Replicate model parameters across all devices for pmap
model_pmap = jax.device_put_replicated(model_cpu, jax.devices())


for epoch in range(num_epochs):
    # For pmap, we need to reshape the global batch from the iterator
    # back into a per-device format.
    for it, (images, labels) in enumerate(training_iterator):
        images_flattened = images.reshape(images.shape[0], -1)

        # Reshape the global batch of device_count * batch_size samples (e.g., 8 devices x 16 = 128)
        # to (device_count, batch_size, features)
        images_pmapped = images_flattened.reshape((device_count, batch_size, -1))
        labels_pmapped = labels.reshape((device_count, batch_size))
        
        batch = {"images": images_pmapped, "labels": labels_pmapped}
        model_pmap = update_parallel(model_pmap, batch)

    validation_device = jax.devices()[0]
    model_on_one_device = jax.tree.map(lambda x: x[0], model_pmap)
    test_acc = accuracy(model_on_one_device, validation_iterator, validation_device)

    print(f"Epoch {epoch}")
    print(f"Test set accuracy with pmap: {test_acc:.4f}")
    training_iterator.reset()
    validation_iterator.reset()
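The reshape in the pmap loop assumes that the global batch concatenates the per-device batches in device order. Under that assumption, reshaping (device_count * batch_size, features) to (device_count, batch_size, features) gives row block i to device i without mixing samples. Replication likewise adds a leading device axis to every parameter, and x[0] removes it again. The NumPy sketch below checks both steps. np.stack stands in for jax.device_put_replicated, and a list comprehension stands in for jax.tree.map. Both substitutions are for illustration only.

```python
import numpy as np

device_count, batch_size, features = 4, 64, 3072
global_batch = np.arange(device_count * batch_size * features, dtype=np.float32)
global_batch = global_batch.reshape(device_count * batch_size, features)

# Split the leading axis into (device, per-device sample).
per_device = global_batch.reshape((device_count, batch_size, -1))
print(per_device.shape)  # (4, 64, 3072)
# Device 2 receives exactly rows 128..191 of the global batch.
print(np.array_equal(per_device[2], global_batch[2 * batch_size:3 * batch_size]))  # True

# Stand-in for jax.device_put_replicated: add a leading device axis to every leaf.
model = [(np.ones((3, 2)), np.zeros(2)), (np.ones((2, 1)), np.zeros(1))]
replicated = [(np.stack([w] * device_count), np.stack([b] * device_count)) for w, b in model]
print(replicated[0][0].shape)  # (4, 3, 2)

# Stand-in for jax.tree.map(lambda x: x[0], model_pmap): keep device 0's copy.
first_copy = [(w[0], b[0]) for w, b in replicated]
print(all(np.array_equal(w0, w) and np.array_equal(b0, b)
          for (w0, b0), (w, b) in zip(first_copy, model)))  # True
```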