jax.extend.backend.get_compile_options

jax.extend.backend.get_compile_options(num_replicas, num_partitions, device_assignment=None, env_options_overrides=None, fdo_profile=None, detailed_logging=True, backend=None)

Returns the compile options to use, as derived from flag values.

Parameters:
  • num_replicas (int) – Number of replicas for which to compile.

  • num_partitions (int) – Number of partitions for which to compile.

  • device_assignment – Optional ndarray of jax devices indicating the assignment of logical replicas to physical devices (default inherited from xla_client.CompileOptions). Must be consistent with num_replicas and num_partitions.

  • env_options_overrides (dict[str, str] | None) – Optional dict of additional options to be parsed by the compiler.

  • fdo_profile (bytes | None) – Optional profile for feedback-directed optimization passed to XLA.

  • detailed_logging (bool) – Whether this is an “interesting” computation for which XLA should log compilation information.

  • backend (xc.Client | None) – The client, if available.

Return type:

xc.CompileOptions