jax.experimental.pallas.tpu.stateful_bits
- jax.experimental.pallas.tpu.stateful_bits(*args, **kwargs)
Sample uniform bits in the form of unsigned integers.
- Parameters:
shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
dtype – optional, an unsigned integer dtype for the returned values (default uint64 if jax_enable_x64 is true, otherwise uint32).
- Returns:
A random array with the specified shape and dtype.
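
The shape and dtype parameters above read the same as those of jax.random.bits, so the sketch below uses that function to illustrate them on any backend. It rests on two assumptions not stated on this page: that stateful_bits accepts the same shape and dtype arguments without an explicit key, and that it is meant to be called inside a Pallas TPU kernel after the on-chip PRNG has been seeded. The names key, bits, and scalar are illustrative only.

```python
import jax
import jax.numpy as jnp

# Explicit-key counterpart, used here only to illustrate shape and dtype.
# Assumption: stateful_bits takes the same shape/dtype arguments but no key.
key = jax.random.PRNGKey(0)

# A 2x3 array of uniformly distributed 32-bit unsigned integers.
bits = jax.random.bits(key, shape=(2, 3), dtype=jnp.uint32)

# With no arguments beyond the key, the result is a scalar (shape ()),
# and the dtype depends on whether jax_enable_x64 is set.
scalar = jax.random.bits(key)

print(bits.shape, bits.dtype)
print(scalar.shape, scalar.dtype)
```

Passing dtype explicitly, as in the first call, keeps the result width independent of the jax_enable_x64 setting.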