jax.experimental.random.stateful_rng



jax.experimental.random.stateful_rng(seed=None, *, impl=None)

Experimental stateful RNG with implicitly-updated state.

This implements a stateful PRNG API similar to numpy.random.default_rng(). It is compatible with JAX transformations such as jit(), with a few exceptions described in the Notes below.
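For comparison, here is the NumPy pattern this API is described as resembling. It uses only NumPy's own Generator API and is not a call into stateful_rng; it simply shows what "implicitly updated state" means: each draw advances hidden state, and re-seeding replays the sequence.

```python
import numpy as np

# NumPy's Generator: every call advances the generator's internal state.
rng = np.random.default_rng(42)
first = rng.uniform()
second = rng.uniform()  # differs from `first` because the state moved on

# A fresh generator with the same seed replays the same sequence.
replay = np.random.default_rng(42)
assert replay.uniform() == first
assert replay.uniform() == second
```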

Note

This stateful PRNG API is a convenience wrapper around JAX’s classic stateless, explicitly updated PRNG, described in jax.random. For performance-critical applications, it is recommended to use jax.random.key() with explicit random state semantics.
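A minimal sketch of the classic stateless style referred to above, using only jax.random.key(), jax.random.split(), and jax.random.uniform(). The caller owns the key and must split it before each draw; the values produced are not claimed to match those of stateful_rng.

```python
import jax

# Explicit state: the key is an ordinary value that the caller threads along.
key = jax.random.key(42)

key, subkey = jax.random.split(key)
x1 = jax.random.uniform(subkey)

key, subkey = jax.random.split(key)
x2 = jax.random.uniform(subkey)

# No hidden state: reusing a key reproduces exactly the same value.
assert jax.random.uniform(subkey) == x2
```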

For a discussion of design considerations for this API, refer to JEP 28845: Stateful Randomness in JAX.

Parameters:
  • seed (ArrayLike | None) – an optional 64- or 32-bit integer used as the value of the key. It must be specified if the generator is instantiated within transformed code; when the generator is created at the top level of the program, it may be omitted, in which case the RNG is seeded using the default NumPy seeding.

  • impl (random.PRNGSpecDesc | None) – optional string specifying the PRNG implementation (e.g. 'threefry2x32')
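The classic jax.random.key() accepts the same kind of seed and impl arguments; the sketch below assumes that the impl strings accepted here (such as 'threefry2x32') are the same names jax.random.key() understands. For threefry2x32, the raw key data is two uint32 words, which matches the two numbers shown after "overlaying:" in the Examples.

```python
import jax

# Same seed/impl style as stateful_rng, but with the classic explicit API.
key = jax.random.key(42, impl="threefry2x32")

# Inspect the raw words underlying the typed key.
data = jax.random.key_data(key)
print(data.shape, data.dtype)  # (2,) uint32
```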

Returns:

A StatefulPRNG object, with methods for generating random values.

Return type:

StatefulPRNG

Notes

The StatefulPRNG object created by this function uses Ref() objects to allow implicit updates of state, and thus inherits some of their limitations. For example:

  • StatefulPRNG objects cannot be among the return values of functions wrapped in jax.jit() or other JAX transformations. In particular, they cannot be used as carry values for jax.lax.scan(), jax.lax.while_loop(), or other JAX control flow.

  • StatefulPRNG objects cannot be used together with jax.checkpoint() or jax.remat(); in these cases it’s best to use the StatefulPRNG.key() method to produce a standard JAX PRNG key.
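A minimal sketch of the usual workaround for control flow: carry a standard key instead of the StatefulPRNG object. Here jax.random.key(0) stands in for a key obtained from StatefulPRNG.key(), and step is an illustrative function name.

```python
import jax

def step(key, _):
    # Split off a fresh subkey for this iteration and carry the new key forward.
    key, subkey = jax.random.split(key)
    return key, jax.random.uniform(subkey)

# A standard JAX key is a valid scan carry; a StatefulPRNG object is not.
key = jax.random.key(0)  # stands in for rng.key()
final_key, samples = jax.lax.scan(step, key, xs=None, length=5)
print(samples.shape)  # (5,)
```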

Examples

>>> from jax.experimental import random
>>> rng = random.stateful_rng(42)

Repeated draws implicitly update the key:

>>> rng.uniform()
Array(0.5302608, dtype=float32)
>>> rng.uniform()
Array(0.72766423, dtype=float32)

This also works under transformations like jax.jit():

>>> import jax
>>> jit_uniform = jax.jit(rng.uniform)
>>> jit_uniform()
Array(0.6672406, dtype=float32)
>>> jit_uniform()
Array(0.3890121, dtype=float32)
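For comparison, the same jit-compiled draw written in the explicit style is a pure function that takes a key and returns the updated key alongside the sample. This is a sketch using only jax.random; draw is a hypothetical name, and its outputs are not expected to match the stateful results above.

```python
import jax

@jax.jit
def draw(key):
    # Pure function: instead of mutating hidden state, return the next key.
    key, subkey = jax.random.split(key)
    return key, jax.random.uniform(subkey)

key = jax.random.key(42)
key, a = draw(key)
key, b = draw(key)  # a different value, because a new key was passed in
```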

Keys can be generated directly if desired:

>>> rng.key()
Array((), dtype=key<fry>) overlaying:
[2954079971 3276725750]
>>> rng.key()
Array((), dtype=key<fry>) overlaying:
[2765691542  824333390]
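Because rng.key() returns an ordinary JAX key, that key can be passed into code where StatefulPRNG objects are not supported, such as a function wrapped in jax.checkpoint(). A minimal sketch, with jax.random.key(1) standing in for the output of rng.key() and noisy_square as a hypothetical function:

```python
import jax
import jax.numpy as jnp

@jax.checkpoint
def noisy_square(x, key):
    # The key is an explicit argument, so recomputation during the backward
    # pass regenerates exactly the same noise.
    noise = jax.random.normal(key, x.shape)
    return jnp.sum((x + noise) ** 2)

key = jax.random.key(1)  # stands in for rng.key()
x = jnp.arange(3.0)
grad = jax.grad(noisy_square)(x, key)  # gradient with respect to x only
print(grad.shape)  # (3,)
```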