rocAL_pybind/amd/rocal/plugin/jax.py File Reference
File containing iterators and functions for the JAX framework.
Data Structures
  class rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator
  class rocAL_pybind.amd.rocal.plugin.jax.ROCALPeekableIterator
Functions
  def rocAL_pybind.amd.rocal.plugin.jax.convert_to_jax_array(array)
  def rocAL_pybind.amd.rocal.plugin.jax.get_spec_for_array(jax_array)
Variables
  bool rocAL_pybind.amd.rocal.plugin.jax.CLU_FOUND = True
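`CLU_FOUND` is listed with the value `True`, which reads like the first assignment in an optional-import guard. The following is a minimal sketch of that pattern. It assumes, without confirmation from this page, that the flag records whether Google's CLU package (`clu`) can be imported. The `ArraySpec` import path and the `require_clu` helper are illustrative assumptions, not rocAL's code.

```python
# Sketch of an optional-dependency guard. Assumption: CLU_FOUND tracks whether
# the CLU package is importable; the real jax.py may implement this differently.
try:
    from clu.data.dataset_iterator import ArraySpec  # optional dependency
    CLU_FOUND = True
except ImportError:
    ArraySpec = None  # placeholder so the name always exists
    CLU_FOUND = False


def require_clu():
    """Hypothetical helper: fail early with a clear message when CLU is missing."""
    if not CLU_FOUND:
        raise ImportError("This feature needs CLU; install it with `pip install clu`.")
```

Guarding the import this way keeps the module importable when CLU is absent. The trade-off is that the failure is deferred to the first call that actually needs CLU.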
Detailed Description
File containing iterators and functions for the JAX framework.
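The module lists `ROCALPeekableIterator`, but this page does not document its methods. The sketch below illustrates only the general meaning of "peekable." The `PeekableIterator` class and its `peek()` method are hypothetical and are not rocAL's API. The idea is that the iterator buffers one element so a caller can inspect the next item before consuming it.

```python
from typing import Generic, Iterable, TypeVar

T = TypeVar("T")
_EMPTY = object()  # marker meaning "no item buffered yet"


class PeekableIterator(Generic[T]):
    """Hypothetical wrapper: iterate normally, or inspect the next item without consuming it."""

    def __init__(self, source: Iterable[T]):
        self._it = iter(source)
        self._buffer = _EMPTY

    def __iter__(self) -> "PeekableIterator[T]":
        return self

    def __next__(self) -> T:
        # Hand out the buffered item first, if peek() already fetched one.
        if self._buffer is not _EMPTY:
            item, self._buffer = self._buffer, _EMPTY
            return item
        return next(self._it)

    def peek(self) -> T:
        # Fetch and buffer the next item; raises StopIteration if the source is exhausted.
        if self._buffer is _EMPTY:
            self._buffer = next(self._it)
        return self._buffer
```

Buffering a single element keeps the overhead constant. In a data pipeline, one common use is reading the shapes of the upcoming batch before the training loop consumes it.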
Function Documentation
◆ convert_to_jax_array()
def rocAL_pybind.amd.rocal.plugin.jax.convert_to_jax_array(array)
Converts an input DLPack tensor to a JAX array.
Args:
    array (DLPack tensor): The array to convert to a JAX array.
Returns:
    jax.Array: A JAX array with the same values, on the same device, as the input array.
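A DLPack-to-JAX conversion can be sketched with JAX's public `jax.dlpack.from_dlpack`. This is not necessarily how rocAL implements `convert_to_jax_array`. The sketch assumes the input implements the DLPack protocol (`__dlpack__` and `__dlpack_device__`), which recent JAX releases accept directly. The function name `convert_to_jax_array_sketch` is hypothetical.

```python
import jax
import jax.dlpack
import jax.numpy as jnp


def convert_to_jax_array_sketch(array):
    """Hypothetical sketch: import a DLPack-compatible tensor as a jax.Array.

    Assumes `array` implements `__dlpack__`/`__dlpack_device__`; the result
    shares the producer's device rather than being copied to a default device.
    """
    return jax.dlpack.from_dlpack(array)


if __name__ == "__main__":
    # Any DLPack producer works here; a jax.Array is used to keep the demo self-contained.
    src = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
    out = convert_to_jax_array_sketch(src)
    print(out.shape, out.dtype)
```

Older producers may hand out raw DLPack capsules instead of objects with `__dlpack__`. Support for capsules has changed across JAX releases, so check the release notes for the installed version.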
◆ get_spec_for_array()
def rocAL_pybind.amd.rocal.plugin.jax.get_spec_for_array(jax_array)
Returns the ArraySpec for a given JAX array.
Args:
    jax_array (jax.Array): The JAX array.
Returns:
    ArraySpec: The specification of the array.
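This page does not say what an `ArraySpec` contains. As the `CLU_FOUND` flag hints, it may be CLU's dataset-iterator spec, which is assumed here to record only an array's dtype and shape. Under that assumption, the function reduces to the sketch below. Both `ArraySpecSketch` and `get_spec_for_array_sketch` are hypothetical stand-ins, not rocAL's code.

```python
import dataclasses
from typing import Any, Tuple

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ArraySpecSketch:
    """Hypothetical stand-in for ArraySpec: assumed to hold only dtype and shape."""
    dtype: Any
    shape: Tuple[int, ...]


def get_spec_for_array_sketch(jax_array) -> ArraySpecSketch:
    # Works for any object exposing .dtype and .shape (jax.Array, numpy.ndarray, ...).
    return ArraySpecSketch(dtype=jax_array.dtype, shape=tuple(jax_array.shape))


if __name__ == "__main__":
    print(get_spec_for_array_sketch(jnp.zeros((4, 3), dtype=jnp.float32)))
```

A spec like this describes a batch without holding its data. That makes it cheap to build and compare, for example when checking that successive batches keep the same layout.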