jax.extend.backend module

backends()

backend_xla_version([platform])

Returns the XLA version of the backend.

clear_backends()

Clear all backend clients so that new backend clients can be created later.

get_backend([platform])

get_compile_options(num_replicas, num_partitions)

Returns the compile options to use, as derived from flag values.

get_default_device()

ifrt_proxy

register_backend_cache(cache, for_what)

Registers a cache with JAX's cache management.

register_backend_factory(name, factory, *[, ...])