rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator Class Reference
Inheritance diagram for rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator: [diagram not reproduced; derived class: rocAL_pybind.amd.rocal.plugin.jax.ROCALPeekableIterator]
Public Member Functions
- def __init__(self, pipelines, sharding=None)
- def next(self)
- def __next__(self)
- def reset(self)
- def place_output_with_device_put(self, individual_outputs)
- def place_output_with_sharding(self, individual_outputs)
- def __iter__(self)
- def __len__(self)
- def __del__(self)
Data Fields
- pipelines
- num_devices
- batch_size
- iterator_length
- last_batch_policy
- sharding
Detailed Description
Initializes the rocAL JAX iterator, which returns batches from one or more rocAL pipelines as jax.Array objects.
Args:
pipelines (list of Pipeline objects): List of rocAL pipelines to use.
sharding (JAX sharding, optional): JAX sharding to use for placing outputs on devices. Defaults to None.
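As a hedged sketch of a value for `sharding`: assuming one rocAL pipeline per device and that a `jax.sharding.NamedSharding` is acceptable here (this page only says "JAX sharding"), the following builds a sharding that splits the leading batch dimension across all visible devices. The mesh axis name `"batch"` is an arbitrary choice, and the commented-out constructor call assumes `pipelines` is a list of already-built rocAL pipelines.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A one-dimensional device mesh; the axis name "batch" is arbitrary.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

# Split dimension 0 (the batch) across the "batch" mesh axis;
# all remaining dimensions are replicated on every device.
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Hypothetical usage, assuming `pipelines` holds built rocAL pipelines:
# iterator = ROCALJaxIterator(pipelines, sharding=sharding)
print(sharding)
```

With this sharding, a global batch of `4 * len(jax.devices())` samples is stored as one block of 4 samples per device.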
Constructor & Destructor Documentation
◆ __del__()
def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.__del__(self)
Releases the rocAL resources.
Member Function Documentation
◆ __iter__()
def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.__iter__(self)
Returns the iterator object.
Reimplemented in rocAL_pybind.amd.rocal.plugin.jax.ROCALPeekableIterator.
◆ __len__()
def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.__len__(self)
Returns the number of batches in the iterator.
◆ __next__()
def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.__next__(self)
Returns the next batch of data.
Reimplemented in rocAL_pybind.amd.rocal.plugin.jax.ROCALPeekableIterator.
◆ next()
def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.next(self)
Returns the next batch of data.
◆ place_output_with_device_put()
def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.place_output_with_device_put(self, individual_outputs)
Builds a sharded jax.Array with `jax.device_put_sharded`; the result is compatible with functions transformed by `jax.pmap`.
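To illustrate the layout: `jax.device_put_sharded(shards, devices)` stacks the shards along a new leading axis of size `len(devices)`, which is the axis `jax.pmap` maps over. The NumPy sketch below reproduces only the resulting shape on the host; it does not move data to devices. The two devices and the per-device batches of four 8x8x3 images are hypothetical placeholders.

```python
import numpy as np

# Hypothetical per-device outputs: 2 devices, each with a batch of
# 4 images of shape 8x8x3, filled with the device index for clarity.
num_devices, batch_size = 2, 4
individual_outputs = [
    np.full((batch_size, 8, 8, 3), device_index, dtype=np.float32)
    for device_index in range(num_devices)
]

# device_put_sharded adds a NEW leading axis with one entry per device;
# np.stack produces the same shape on the host:
# (num_devices, batch_size, H, W, C).
pmap_layout = np.stack(individual_outputs)
print(pmap_layout.shape)  # (2, 4, 8, 8, 3)

# Index i along axis 0 is the batch that device i would receive.
print(pmap_layout[1].shape)  # (4, 8, 8, 3)
```

A `jax.pmap`-ed function then sees one `(batch_size, H, W, C)` slice per device.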
◆ place_output_with_sharding()
def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.place_output_with_sharding(self, individual_outputs)
Builds a sharded jax.Array with `jax.make_array_from_single_device_arrays`; the result is compatible with JAX's automatic parallelization (for example, functions compiled with `jax.jit` that consume sharded inputs).
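A minimal sketch of that JAX call, assuming the sharding splits the leading batch dimension across a one-axis mesh of all visible devices. The per-device arrays are hypothetical placeholders standing in for per-pipeline outputs. The global array's batch dimension is the per-device batch size times the number of devices, with no extra device axis. It also runs on a single CPU device.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

per_device_batch = 4
# Hypothetical per-device outputs, each committed to its own device and
# filled with that device's index.
individual_outputs = [
    jax.device_put(np.full((per_device_batch, 3), i, dtype=np.float32), d)
    for i, d in enumerate(devices)
]

# The per-device batches are concatenated along axis 0 in the global view.
global_shape = (per_device_batch * len(devices), 3)
global_array = jax.make_array_from_single_device_arrays(
    global_shape, sharding, individual_outputs
)
print(global_array.shape)
```

Because the data already lives on the right devices, this assembles a logical global view without copying the batches again.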
◆ reset()
def rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator.reset(self)
Resets the iterator for the next epoch.
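The loop below sketches the epoch pattern implied by `__iter__`, `__next__`, `__len__`, and `reset()`: iterate until the iterator stops, then call `reset()` before the next epoch. rocAL pipelines cannot be built in a standalone snippet, so `FakeJaxIterator` is a hypothetical pure-Python stand-in that only mimics this protocol. It is not rocAL code, and its dictionary batches are placeholders for the real iterator's outputs.

```python
class FakeJaxIterator:
    """Hypothetical stand-in for the iterator protocol on this page."""

    def __init__(self, num_batches):
        self.iterator_length = num_batches
        self._position = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._position >= self.iterator_length:
            # End of epoch: raising StopIteration ends a for-loop.
            raise StopIteration
        self._position += 1
        return {"batch_index": self._position - 1}

    def next(self):
        return self.__next__()

    def __len__(self):
        return self.iterator_length

    def reset(self):
        # Rewind so the next epoch starts from the first batch.
        self._position = 0


def run_epochs(loader, num_epochs):
    """Consume every batch in each epoch and reset between epochs."""
    batches_per_epoch = []
    for _ in range(num_epochs):
        count = 0
        for _batch in loader:  # stops when __next__ raises StopIteration
            count += 1
        batches_per_epoch.append(count)
        # Without this call, the stand-in yields nothing in later epochs.
        loader.reset()
    return batches_per_epoch


print(run_epochs(FakeJaxIterator(3), num_epochs=2))  # [3, 3]
```

`len(loader)` gives the number of batches per epoch, which is handy for sizing a training schedule.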
The documentation for this class was generated from the following file:
- /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-rocal/checkouts/develop/rocAL_pybind/amd/rocal/plugin/jax.py