jax.experimental.random.stateful_rng
- jax.experimental.random.stateful_rng(seed=None, *, impl=None)
Experimental stateful RNG with implicitly-updated state.
This implements a stateful PRNG API similar to numpy.random.default_rng(). It is compatible with JAX transformations like jit() and others, with a few exceptions mentioned in the Notes below.
Note
This stateful PRNG API is a convenience wrapper around JAX’s classic stateless, explicitly updated PRNG, described in jax.random. For performance-critical applications, it is recommended to use jax.random.key() with explicit random state semantics. For a discussion of design considerations for this API, refer to JEP 28845: Stateful Randomness in JAX.
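To make the contrast with the explicit API concrete, here is a minimal sketch of the kind of bookkeeping the wrapper performs implicitly, written only with the stable functions jax.random.key(), jax.random.split(), and jax.random.uniform(). The helper name draw_uniform is hypothetical, and this illustrates the explicit-state pattern in general, not the wrapper's actual implementation; its values will not necessarily match the outputs shown in the Examples.

```python
import jax

def draw_uniform(key):
    """Hypothetical helper: one explicit-state draw, returning (new_key, sample)."""
    # Split so the subkey is consumed by this draw and the new key is carried forward.
    key, subkey = jax.random.split(key)
    return key, jax.random.uniform(subkey)

key = jax.random.key(42)
key, x1 = draw_uniform(key)  # the caller must thread the returned key...
key, x2 = draw_uniform(key)  # ...otherwise the same randomness would be reused
```

With stateful_rng, repeated calls such as rng.uniform() take care of this key threading for you.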
- Parameters:
seed (ArrayLike | None) – an optional 64- or 32-bit integer used as the value of the key. This must be specified if the generator is instantiated within transformed code; when used at the top level of the program, it may be omitted, in which case the RNG will be seeded using the default NumPy seeding.
impl (random.PRNGSpecDesc | None) – optional string specifying the PRNG implementation (e.g. 'threefry2x32').
- Returns:
A StatefulPRNG object, with methods for generating random values.
- Return type:
StatefulPRNG
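The seed and impl parameters have the same types as those of jax.random.key(), which accepts the same kind of implementation name. Assuming stateful_rng forwards them in a similar way (an assumption; this page does not say so), the effect of impl can be seen directly on ordinary keys: the chosen implementation determines the size of the raw key data.

```python
import jax

# threefry2x32 (JAX's default implementation) stores two uint32 words per key.
k_fry = jax.random.key(0, impl="threefry2x32")
# The 'rbg' implementation stores four uint32 words per key.
k_rbg = jax.random.key(0, impl="rbg")

# key_data exposes the raw words underlying each typed key.
fry_data = jax.random.key_data(k_fry)
rbg_data = jax.random.key_data(k_rbg)
```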
Notes
The StatefulPRNG object created by this method uses Ref() objects to allow implicit updates of state, and thus inherits some of their limitations. For example:
- StatefulPRNG objects cannot be among the return values of functions wrapped in jax.jit() or other JAX transformations. In particular, this means they cannot be used as carry values for jax.lax.scan(), jax.lax.while_loop(), and other JAX control flow.
- StatefulPRNG objects cannot be used together with jax.checkpoint() or jax.remat(); in these cases it’s best to use the StatefulPRNG.key() method to produce a standard JAX PRNG key.
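Following the advice in the Notes, code that needs a scan carry or jax.checkpoint() can work with an ordinary key instead. The sketch below uses jax.random.key(0) as a stand-in for a key obtained from rng.key(), so it runs without the experimental module; the function names body and noisy_shift are hypothetical.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)  # stand-in for a key produced by rng.key()

def body(carry_key, _):
    # A plain key is a valid scan carry: split, use one half, carry the other.
    carry_key, subkey = jax.random.split(carry_key)
    return carry_key, jax.random.uniform(subkey)

final_key, samples = jax.lax.scan(body, key, xs=None, length=5)

@jax.checkpoint
def noisy_shift(x, k):
    # The key is an explicit argument, so recomputation reproduces the same noise.
    return x + jax.random.normal(k, x.shape)

# The noise does not depend on x, so the gradient of the sum is all ones.
grad = jax.grad(lambda x: noisy_shift(x, key).sum())(jnp.zeros(3))
```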
Examples
>>> from jax.experimental import random
>>> rng = random.stateful_rng(42)
Repeated draws implicitly update the key:
>>> rng.uniform()
Array(0.5302608, dtype=float32)
>>> rng.uniform()
Array(0.72766423, dtype=float32)
This also works under transformations like jax.jit():
>>> import jax
>>> jit_uniform = jax.jit(rng.uniform)
>>> jit_uniform()
Array(0.6672406, dtype=float32)
>>> jit_uniform()
Array(0.3890121, dtype=float32)
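The jitted calls keep producing new values because the wrapper updates its state implicitly. By contrast, a jitted function over an explicit key is pure: the same key always yields the same value, and the caller must pass the updated key back in. A minimal sketch with the stable jax.random API (draw is a hypothetical name):

```python
import jax

@jax.jit
def draw(key):
    # Pure: the result depends only on the key argument.
    key, subkey = jax.random.split(key)
    return key, jax.random.uniform(subkey)

key = jax.random.key(42)
_, a = draw(key)
_, b = draw(key)          # same key in, same value out
next_key, c = draw(key)
_, d = draw(next_key)     # threading the returned key gives a fresh value
```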
Keys can be generated directly if desired:
>>> rng.key()
Array((), dtype=key<fry>) overlaying:
[2954079971 3276725750]
>>> rng.key()
Array((), dtype=key<fry>) overlaying:
[2765691542 824333390]
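A key returned by rng.key() is a standard typed JAX key, so it can be passed to any jax.random function, and jax.random.key_data() exposes the raw words shown after "overlaying:". In this sketch jax.random.key(42) stands in for such a key; its raw words are not expected to match the values printed above.

```python
import jax

key = jax.random.key(42)  # stand-in for a key returned by rng.key()

# Raw uint32 words underlying the typed key (two words for threefry2x32).
raw = jax.random.key_data(key)

# Split into independent keys and draw a length-4 vector from each one.
keys = jax.random.split(key, 3)
samples = jax.vmap(lambda k: jax.random.normal(k, (4,)))(keys)
```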