jax.experimental.random.StatefulPRNG
class jax.experimental.random.StatefulPRNG(_base_key, _counter)
Stateful JAX random generator.
Instances should be created with the jax.experimental.random.stateful_rng() function.
_base_key
A typed JAX PRNG key object (see jax.random.key()).
Type:
Examples:
>>> from jax.experimental import random
>>> rng = random.stateful_rng(42)
>>> rng
StatefulPRNG(_base_key=Array((), dtype=key<fry>) overlaying: [ 0 42], _counter=Ref(0, dtype=int32, weak_type=True))
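The repr in the example shows two pieces of state: a base key and a counter. A minimal sketch of how such a counter-based generator can work, assuming (this page does not confirm it) that each new key is derived by folding the current counter into the base key with jax.random.fold_in and then advancing the counter. The class name CounterRNG is hypothetical, and its counter is a plain Python int, so unlike the Ref in the repr it is not updated correctly inside jax.jit.

```python
import jax
import jax.numpy as jnp


class CounterRNG:
    """Hypothetical, simplified stateful generator (not the JAX implementation).

    Assumption: each call derives a fresh key by folding a monotonically
    increasing counter into a fixed base key, then increments the counter.
    """

    def __init__(self, seed: int):
        self._base_key = jax.random.key(seed)  # typed PRNG key, never mutated
        self._counter = 0  # plain Python int standing in for mutable state

    def key(self):
        # Derive a new key from (base key, counter), then advance the state.
        k = jax.random.fold_in(self._base_key, self._counter)
        self._counter += 1
        return k

    def normal(self, shape=()):
        # Each draw consumes one freshly derived key.
        return jax.random.normal(self.key(), shape)


rng = CounterRNG(42)
a = rng.normal((3,))
b = rng.normal((3,))
```

Re-creating the generator with the same seed replays the same sequence of draws, while successive draws from one generator use different keys and so differ.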
Methods
__init__(_base_key, _counter)
integers(low[, high, size, dtype]): Draw pseudorandom integers.
key([shape]): Generate a new JAX PRNGKey, updating the internal state.
normal([loc, scale, size, dtype]): Draw normally distributed pseudorandom values.
random([size, dtype]): Return random floats in the half-open interval [0.0, 1.0).
spawn(n_children): Create a list of independent child generators.
split(num): Create independent child generators suitable for use in jax.vmap().
uniform([low, high, size, dtype]): Draw uniformly distributed pseudorandom values.
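split() is described as producing child generators suitable for jax.vmap(). Its internals are not documented on this page, so the sketch below only shows the plain functional jax.random pattern for batched sampling that such a method presumably mirrors: split one key into a batch of keys and map a sampler over that batch. The helper draw_one is hypothetical.

```python
import jax
import jax.numpy as jnp


def draw_one(key):
    # Hypothetical per-example sampler: each mapped call gets its own key.
    return jax.random.uniform(key, (2,))


base = jax.random.key(0)
keys = jax.random.split(base, 4)  # batch of 4 independent typed keys, shape (4,)
samples = jax.vmap(draw_one)(keys)  # one row of 2 uniforms per key, shape (4, 2)
```

Because every mapped call receives a distinct key, the rows are independent draws rather than copies of the same sample, which is the property a vmap-friendly split needs to preserve.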