Using rocAL with JAX for training

Using rocAL with JAX for training#

rocAL improves machine learning (ML) pipeline efficiency by preprocessing data and parallelizing data loading.

JAX iterators are provided as plugins to separate data loading from training.

You’ll need a rocAL JAX Docker container to run JAX training with rocAL.

To use rocAL with JAX, import the rocAL JAX plugin:

from amd.rocal.plugin.jax import ROCALJaxIterator

Get the number of available devices using jax.device_count(). Set up a training pipeline that partitions the training data and run the pipeline on each device using ROCALJaxIterator.

A Jupyter Notebook is available as an example of using JAX with rocAL.