jax.random.wrap_key_data
- jax.random.wrap_key_data(key_bits_array, *, impl=None, dtype=None)
Wrap an array of key data bits into a PRNG key array.
- Parameters:
key_bits_array (Array) – a uint32 array with trailing shape corresponding to the key shape of the PRNG implementation specified by impl.
impl (PRNGSpecDesc | None) – optional, specifies a PRNG implementation, as in random.key.
dtype (KeyDTypeLike | None) – optional dtype or string name specifying the PRNG implementation (e.g. jax.random.key_dtype('threefry2x32') or 'threefry2x32').
- Returns:
- A PRNG key array, whose dtype is a subdtype of jax.dtypes.prng_key corresponding to impl, and whose shape equals the leading shape of key_bits_array.shape up to the key bit dimensions.
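The shape rule for the return value can be checked with a small batch of keys. This is an illustrative sketch rather than part of the original reference: it assumes the 'threefry2x32' implementation, whose key data is two uint32 words per key, so a (3, 2) bit array should wrap into a key array of shape (3,). The bit values themselves are arbitrary.

```python
import jax
import jax.numpy as jnp

# Arbitrary bits for 3 keys; assumes threefry2x32, whose key data is 2 uint32 words.
bits = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)

# The trailing (2,) dimension is consumed as key data; the leading (3,) remains.
keys = jax.random.wrap_key_data(bits, impl='threefry2x32')

print(keys.shape)  # (3,)
print(jnp.issubdtype(keys.dtype, jax.dtypes.prng_key))  # True

# Unwrapping with key_data recovers the original bits with their full shape.
recovered = jax.random.key_data(keys)
print(recovered.shape)  # (3, 2)
```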
Examples
Construct a key, and extract its data and dtype:
>>> import jax
>>> key = jax.random.key(42)
>>> data = jax.random.key_data(key)
>>> dtype = key.dtype
Reconstruct an equivalent key with wrap_key_data():
>>> new_key = jax.random.wrap_key_data(data, dtype=dtype)
>>> key == new_key
Array(True, dtype=bool)
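A related use is upgrading a legacy raw key (the plain uint32 array returned by jax.random.PRNGKey) into a typed key. The sketch below assumes JAX's default configuration, where the default PRNG implementation is 'threefry2x32' and a raw key has shape (2,); under that assumption the raw key and its wrapped counterpart should drive the same random stream.

```python
import jax
import jax.numpy as jnp

# A legacy raw key: a plain uint32 array of shape (2,) under the default
# threefry2x32 implementation (assumes default JAX configuration).
raw = jax.random.PRNGKey(0)

# Wrap the raw bits into a typed key, naming the implementation explicitly.
typed = jax.random.wrap_key_data(raw, impl='threefry2x32')

# Both forms should produce identical samples.
a = jax.random.uniform(raw, (3,))
b = jax.random.uniform(typed, (3,))
print(bool(jnp.array_equal(a, b)))  # True
```

Naming impl explicitly, instead of relying on the configured default, keeps the wrapped key tied to the implementation the bits were generated with.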