jax.experimental.random.StatefulPRNG

class jax.experimental.random.StatefulPRNG(_base_key, _counter)

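Stateful JAX random number generator.

With the functional jax.random API, a key is passed explicitly to every sampling function. A StatefulPRNG instead keeps its own state (a base key and a counter) and updates that state as values are drawn.

Create instances with the jax.experimental.random.stateful_rng() function rather than by calling this constructor directly.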

Parameters:
_base_key

a typed JAX PRNG key object (see jax.random.key()).

Type:

Array

_counter

a scalar integer wrapped in a jax.Ref.

Type:

core.Ref
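To show why a base key plus an integer counter is enough state to produce a stream of fresh keys, here is a minimal sketch written against the plain jax.random API. It assumes a jax.random.fold_in-based derivation, and it is not necessarily how StatefulPRNG derives its keys. The class CounterKeyStream and its method next_key are hypothetical. A plain Python int stands in for the jax.Ref counter, so the sketch is only meant for eager (non-jit) use.

```python
import jax
import jax.numpy as jnp


class CounterKeyStream:
    """Hypothetical illustration: a typed base key plus a mutable counter."""

    def __init__(self, seed: int):
        self._base_key = jax.random.key(seed)  # typed PRNG key, as in _base_key
        self._counter = 0                      # Python int stand-in for the Ref

    def next_key(self):
        # Derive a fresh key from (base key, counter), then advance the counter
        # so the next call yields a different key.
        k = jax.random.fold_in(self._base_key, self._counter)
        self._counter += 1
        return k


stream = CounterKeyStream(42)
k0 = stream.next_key()
k1 = stream.next_key()
x0 = jax.random.normal(k0, (3,))  # draws from the first derived key
x1 = jax.random.normal(k1, (3,))  # draws from the second derived key
```

Because each key depends only on the seed and the counter value, a new stream built from the same seed reproduces the same sequence of keys.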

Examples:

>>> from jax.experimental import random
>>> rng = random.stateful_rng(42)
>>> rng
StatefulPRNG(_base_key=Array((), dtype=key<fry>) overlaying:
[ 0 42], _counter=Ref(0, dtype=int32, weak_type=True))
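
__init__(_base_key, _counter)

Parameters:
_base_key (Array): a typed JAX PRNG key object.
_counter (core.Ref): a scalar integer wrapped in a jax.Ref.

Return type:
None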

Methods

__init__(_base_key, _counter)

integers(low[, high, size, dtype]): Draw pseudorandom integers.

key([shape]): Generate a new JAX PRNG key, updating the internal state.

normal([loc, scale, size, dtype]): Draw normally distributed pseudorandom values.

random([size, dtype]): Return random floats in the half-open interval [0.0, 1.0).

spawn(n_children): Create a list of independent child generators.

split(num): Create independent child generators suitable for use with jax.vmap().

uniform([low, high, size, dtype]): Draw uniformly distributed pseudorandom values.