jax.experimental.pallas.tpu.stateful_bits

jax.experimental.pallas.tpu.stateful_bits(*args, **kwargs)

Sample uniform bits in the form of unsigned integers.

Parameters:
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, an unsigned integer dtype for the returned values (default uint64 if jax_enable_x64 is true, otherwise uint32).

  • out_sharding – optional, how the output array should be sharded across devices in a multi-device computation: a NamedSharding, a PartitionSpec (P), or None (default). When given, the output is sharded according to that specification. Primarily used in explicit sharding mode; see the explicit sharding tutorial for details.

Returns:

A random array with the specified shape and dtype.
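The sketch below shows one way to call stateful_bits from inside a Pallas TPU kernel. It is a minimal sketch that assumes a TPU backend and assumes the kernel's PRNG state must first be seeded with pltpu.prng_seed before sampling. The helper names bits_kernel and sample_bits, the (8, 128) output shape, and passing the seed through SMEM are illustrative choices, not part of this API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def bits_kernel(seed_ref, out_ref):
    # Assumption: the stateful PRNG is seeded once per kernel invocation,
    # here from an int32 scalar stored in SMEM.
    pltpu.prng_seed(seed_ref[0])
    # Draw uniform uint32 bits with the same shape as the output block.
    out_ref[...] = pltpu.stateful_bits(out_ref.shape, dtype=jnp.uint32)


def sample_bits(seed, shape=(8, 128)):
    # Hypothetical helper: no grid, so the kernel runs once over the
    # whole output array, which lives in the default (VMEM) memory space.
    return pl.pallas_call(
        bits_kernel,
        out_shape=jax.ShapeDtypeStruct(shape, jnp.uint32),
        in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
    )(jnp.array([seed], dtype=jnp.int32))
```

On a TPU, sample_bits(0) should return an (8, 128) uint32 array. Passing dtype=jnp.uint32 explicitly keeps the output dtype fixed regardless of jax_enable_x64, whose setting otherwise selects the default described under dtype.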