/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-rocal/checkouts/develop/rocAL_pybind/amd/rocal/plugin/jax.py File Reference

jax.py File Reference

File containing iterators and functions for the JAX framework.

Data Structures

class  rocAL_pybind.amd.rocal.plugin.jax.ROCALJaxIterator
 
class  rocAL_pybind.amd.rocal.plugin.jax.ROCALPeekableIterator
 

Functions

def rocAL_pybind.amd.rocal.plugin.jax.convert_to_jax_array (array)
 
def rocAL_pybind.amd.rocal.plugin.jax.get_spec_for_array (jax_array)
 

Variables

bool rocAL_pybind.amd.rocal.plugin.jax.CLU_FOUND = True
 

Detailed Description

File containing iterators and functions for the JAX framework.
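This page lists ROCALPeekableIterator but does not document its methods. As background only, a "peekable" iterator conventionally lets a caller inspect the next element without consuming it. The sketch below illustrates that general idea in plain Python; the class `PeekableIterator` and its `peek` method are hypothetical and are not the API of `ROCALPeekableIterator`.

```python
from typing import Generic, Iterable, Iterator, TypeVar

T = TypeVar("T")

_EMPTY = object()  # sentinel: nothing buffered yet


class PeekableIterator(Generic[T]):
    """Hypothetical illustration of a peekable iterator (not rocAL's class)."""

    def __init__(self, source: Iterable[T]):
        self._source: Iterator[T] = iter(source)
        self._buffered = _EMPTY

    def __iter__(self) -> "PeekableIterator[T]":
        return self

    def peek(self) -> T:
        """Return the next element without consuming it (StopIteration if exhausted)."""
        if self._buffered is _EMPTY:
            self._buffered = next(self._source)
        return self._buffered

    def __next__(self) -> T:
        # Hand out a previously peeked element before pulling a new one.
        if self._buffered is not _EMPTY:
            item, self._buffered = self._buffered, _EMPTY
            return item
        return next(self._source)


it = PeekableIterator([10, 20, 30])
first_peek = it.peek()   # inspects 10 without consuming it
second_peek = it.peek()  # still 10
consumed = list(it)      # 10, 20, 30
```

Buffering a single element is enough for one-step lookahead; a training loop can use it, for example, to read the first batch's shapes before iteration starts.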

Function Documentation

convert_to_jax_array()

def rocAL_pybind.amd.rocal.plugin.jax.convert_to_jax_array (array)
Converts the input DLPack tensor to a JAX array.

Args:
    array (DLPack tensor):
        The array to be converted to a JAX array.

Returns:
    jax.Array: A JAX array with the same values and device as the input array.
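A minimal sketch of this kind of conversion, assuming the input is an object that implements the DLPack Python protocol (`__dlpack__` and `__dlpack_device__`), which `jax.dlpack.from_dlpack` accepts in recent JAX releases. The helper name `dlpack_to_jax` is hypothetical, and a JAX array stands in for the tensor rocAL would produce; this is not rocAL's actual implementation.

```python
import jax
import jax.dlpack
import jax.numpy as jnp


def dlpack_to_jax(tensor):
    """Hypothetical stand-in for convert_to_jax_array.

    Imports any object implementing __dlpack__ as a jax.Array on the
    producer's device.
    """
    return jax.dlpack.from_dlpack(tensor)


# A jax.Array also implements __dlpack__, so it can act as the producer here;
# in rocAL the producer would be the pipeline's output tensor (assumed).
producer = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
result = dlpack_to_jax(producer)
```

Older JAX versions expected a raw DLPack capsule instead of a protocol object, so the accepted input type depends on the installed JAX version.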

get_spec_for_array()

def rocAL_pybind.amd.rocal.plugin.jax.get_spec_for_array (jax_array)
Returns the ArraySpec for a given JAX array.

Args:
    jax_array (jax.Array): The JAX array.

Returns:
    ArraySpec: The specification of the array.
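The page does not define the fields of ArraySpec. The `CLU_FOUND` flag suggests the module may use Common Loop Utils (CLU), whose `clu.data.dataset_iterator` module defines an `ArraySpec` with `dtype` and `shape` fields, but that is an assumption. To stay self-contained, the sketch below defines its own hypothetical `ArraySpec` stand-in and a hypothetical `spec_for_array` helper, then maps the helper over a dict of arrays such as a data iterator might yield.

```python
import dataclasses
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as np


@dataclasses.dataclass(frozen=True)
class ArraySpec:
    """Hypothetical spec record; the dtype/shape fields are an assumption."""
    dtype: np.dtype
    shape: Tuple[int, ...]


def spec_for_array(jax_array: jax.Array) -> ArraySpec:
    # Reads only metadata; the array's data is not copied to the host.
    return ArraySpec(dtype=np.dtype(jax_array.dtype), shape=tuple(jax_array.shape))


# A batch shaped like a dict of arrays (assumed layout for illustration).
batch = {
    "images": jnp.zeros((4, 224, 224, 3), dtype=jnp.float32),
    "labels": jnp.zeros((4,), dtype=jnp.int32),
}
# ArraySpec is not a registered pytree node, so each spec is a leaf of the result.
specs = jax.tree_util.tree_map(spec_for_array, batch)
```

Describing a batch by dtype and shape alone lets consumers allocate buffers or build model inputs without touching the data itself.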