rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator Class Reference

Inheritance: rocAL_pybind.amd.rocal.plugin.jax.ROCALPeekableIterator derives from rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.

Public Member Functions

def __init__(self, pipelines, sharding=None)
def next(self)
def __next__(self)
def reset(self)
def place_output_with_device_put(self, individual_outputs)
def place_output_with_sharding(self, individual_outputs)
def __iter__(self)
def __len__(self)
def __del__(self)

Data Fields

pipelines
num_devices
batch_size
iterator_length
last_batch_policy
sharding

Detailed Description

Initializes the rocAL JAX iterator.

Args:
    pipelines (list of Pipeline objects): List of rocAL pipelines to use.
    sharding (JAX sharding, optional): JAX sharding to use for placing outputs on devices. Defaults to None.

Constructor & Destructor Documentation

◆ __del__()

def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.__del__(self)
Releases the rocAL resources.

Member Function Documentation

◆ __iter__()

def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.__iter__(self)
Returns the iterator object.

Reimplemented in rocAL_pybind.amd.rocal.plugin.jax.ROCALPeekableIterator.

◆ __len__()

def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.__len__(self)
Returns the number of batches in the iterator.
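How `__len__` relates to the `batch_size` and `last_batch_policy` fields is not spelled out on this page. The sketch below assumes a common convention: a policy that drops the incomplete final batch rounds the batch count down, and any policy that keeps it (padded or partial) rounds up. The function `batches_per_epoch` and the policy labels are hypothetical illustrations, not rocAL's API.

```python
# Hypothetical policy labels for this sketch; rocAL's real
# last_batch_policy values are not documented on this page.
DROP = "drop"
KEEP_PARTIAL = "partial"


def batches_per_epoch(num_samples: int, batch_size: int, policy: str) -> int:
    """Batches one epoch yields under an assumed last-batch policy."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if policy == DROP:
        # Assumption: the incomplete final batch is discarded (round down).
        return num_samples // batch_size
    # Assumption: the incomplete final batch is kept (integer ceiling).
    return (num_samples + batch_size - 1) // batch_size


print(batches_per_epoch(1000, 64, DROP))          # 15
print(batches_per_epoch(1000, 64, KEEP_PARTIAL))  # 16
```

When the sample count is an exact multiple of the batch size, both policies give the same length.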

◆ __next__()

def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.__next__(self)
Returns the next batch of data.

Reimplemented in rocAL_pybind.amd.rocal.plugin.jax.ROCALPeekableIterator.

◆ next()

def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.next(self)
Returns the next batch of data.

◆ place_output_with_device_put()

def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.place_output_with_device_put(self, individual_outputs)
Builds a sharded jax.Array with `jax.device_put_sharded`; the result is compatible with pmapped JAX functions.
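`jax.device_put_sharded(shards, devices)` places the i-th shard on the i-th device and stacks the shards along a new leading axis, which is the axis `jax.pmap` maps over. The numpy sketch below reproduces only the resulting global shape, without any device placement; it assumes `individual_outputs` holds one equally shaped batch per device, and all sizes are illustrative.

```python
import numpy as np

num_devices = 2           # illustrative: one rocAL pipeline per device
batch_size = 4            # illustrative per-device batch size
sample_shape = (3, 8, 8)  # illustrative sample shape (C, H, W)

# One batch per device, filled with its device index so it is traceable.
individual_outputs = [
    np.full((batch_size, *sample_shape), device_id, dtype=np.float32)
    for device_id in range(num_devices)
]

# device_put_sharded stacks per-device batches along a NEW leading axis,
# so the global array is indexed as [device, sample, ...].
stacked = np.stack(individual_outputs, axis=0)
print(stacked.shape)  # (2, 4, 3, 8, 8)
```

Inside a pmapped function, each device then receives one `(batch_size, *sample_shape)` slice with the device axis removed.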

◆ place_output_with_sharding()

def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.place_output_with_sharding(self, individual_outputs)
Builds a sharded jax.Array with `jax.make_array_from_single_device_arrays`; the result is compatible with JAX's automatic parallelization.
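`jax.make_array_from_single_device_arrays(shape, sharding, arrays)` assembles one global `jax.Array` from arrays that already live on individual devices, without adding a new axis. Assuming the `sharding` passed to the constructor splits the leading (batch) axis across devices, the global shape is the per-device batches concatenated along that axis. The numpy sketch below shows only that shape arithmetic; the sizes are illustrative.

```python
import numpy as np

num_devices = 2           # illustrative device count
batch_size = 4            # illustrative per-device batch size
sample_shape = (3, 8, 8)  # illustrative sample shape (C, H, W)

individual_outputs = [
    np.full((batch_size, *sample_shape), device_id, dtype=np.float32)
    for device_id in range(num_devices)
]

# Global shape to pass as `shape` when the batch axis is sharded:
# the device count folds into the batch dimension.
global_shape = (num_devices * batch_size, *individual_outputs[0].shape[1:])

# Host-side picture of the logical global array: concatenation, not stacking.
global_view = np.concatenate(individual_outputs, axis=0)
assert global_view.shape == global_shape
print(global_shape)  # (8, 3, 8, 8)
```

Unlike the stacked pmap layout, a jitted function sees the full `(num_devices * batch_size, ...)` array and JAX handles the partitioning.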

◆ reset()

def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.reset(self)
Resets the iterator for the next epoch.
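A typical training loop iterates the object once per epoch and calls `reset()` before the next pass. Real rocAL pipelines need a GPU and a dataset, so the runnable sketch below uses a hypothetical `ToyIterator` that only mimics the documented protocol (`__iter__`, `__next__`, `next`, `__len__`, `reset`). It assumes that the end of an epoch is signalled with `StopIteration`, as in Python's standard iterator convention.

```python
class ToyIterator:
    """Hypothetical stand-in mimicking ROCALJaxIterator's iteration protocol."""

    def __init__(self, num_batches):
        self.num_batches = num_batches
        self._position = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._position >= self.num_batches:
            # Assumption: end of epoch is signalled the standard Python way.
            raise StopIteration
        batch = f"batch-{self._position}"
        self._position += 1
        return batch

    def next(self):
        return self.__next__()

    def __len__(self):
        return self.num_batches

    def reset(self):
        # Rewind to the first batch so the next epoch can start.
        self._position = 0


def run_epochs(iterator, num_epochs):
    seen = []
    for epoch in range(num_epochs):
        for batch in iterator:
            seen.append((epoch, batch))
        iterator.reset()  # without this, the toy iterator yields nothing next epoch
    return seen


history = run_epochs(ToyIterator(num_batches=3), num_epochs=2)
print(len(history))  # 6
```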

The documentation for this class was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-rocal/checkouts/develop/rocAL_pybind/amd/rocal/plugin/jax.py