jax.experimental.pallas.tpu.stateful_bernoulli

jax.experimental.pallas.tpu.stateful_bernoulli(*args, **kwargs)

Sample Bernoulli random values with given shape and mean.

The values are distributed according to the probability mass function:

\[f(k; p) = p^k(1 - p)^{1 - k}\]

where \(k \in \{0, 1\}\) and \(0 \le p \le 1\).
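As a quick sanity check of the formula, here is a pure-Python sketch that evaluates \(f(k; p)\) with exact fractions for \(p = 3/10\). It does not call JAX, and `bernoulli_pmf` is a hypothetical helper written only for illustration. The two masses sum to 1, and the mean \(0 \cdot f(0; p) + 1 \cdot f(1; p)\) equals \(p\), which is why \(p\) is also referred to as the mean.

```python
from fractions import Fraction


def bernoulli_pmf(k: int, p: Fraction) -> Fraction:
    """Hypothetical helper: evaluate f(k; p) = p**k * (1 - p)**(1 - k)."""
    if k not in (0, 1):
        raise ValueError("k must be 0 or 1")
    if not 0 <= p <= 1:
        raise ValueError("p must lie in [0, 1]")
    return p**k * (1 - p) ** (1 - k)


p = Fraction(3, 10)
f0 = bernoulli_pmf(0, p)  # probability of a False sample
f1 = bernoulli_pmf(1, p)  # probability of a True sample
total = f0 + f1           # the two masses must sum to 1
mean = 0 * f0 + 1 * f1    # expected value of k, which equals p

print(f0, f1, total, mean)  # 7/10 3/10 1 3/10
```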

Parameters:
  • p – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with shape. Default 0.5.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with p.shape. The default (None) produces a result shape equal to p.shape.

  • mode – optional, "high" or "low", controlling how many random bits are used when sampling. Default "low". Set to "high" for correct sampling at small values of p: when sampling in float32, mode="low" produces incorrect results for p < ~1e-7. mode="high" approximately doubles the cost of sampling.

  • out_sharding – optional, specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output is sharded according to the given sharding specification. Primarily used in explicit sharding mode; see the explicit sharding tutorial for more details.
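The float32 limit on mode="low" can be illustrated with exact arithmetic. The sketch below assumes, as a simplification rather than a description of the Pallas TPU implementation, that "low" draws a single float32 uniform value u carrying 23 random mantissa bits (so u lies on the grid j / 2**23 for 0 <= j < 2**23) and returns u < p. Under that assumption the probability of a True sample is ceil(p * 2**23) / 2**23, which can never fall below 2**-23 (about 1.19e-7) for any p > 0. `low_mode_probability` is a hypothetical helper.

```python
import math
from fractions import Fraction

# Assumption: a float32 uniform draw in [0, 1) carries 23 random bits,
# so it takes values j / 2**23 for integers 0 <= j < 2**23.
MANTISSA_BITS = 23
GRID = 2**MANTISSA_BITS


def low_mode_probability(p: Fraction) -> Fraction:
    """Exact P(u < p) for u uniform on the assumed 2**-23 grid, 0 <= p <= 1."""
    # The grid points j / GRID strictly below p are j = 0, ..., ceil(p * GRID) - 1.
    hits = math.ceil(p * GRID)
    return Fraction(hits, GRID)


# A p that lies on the grid is reproduced exactly.
print(low_mode_probability(Fraction(1, 4)))  # 1/4

# Below 2**-23 the only hit is u == 0, so the probability is stuck at 2**-23.
tiny = Fraction(1, 10**8)
print(low_mode_probability(tiny) == Fraction(1, GRID))  # True
```

Under the same assumption, p = 1e-8 is sampled as if p were about 1.19e-7, roughly 12 times too often, which is consistent with the ~1e-7 threshold stated for mode="low". Drawing more random bits, as mode="high" does, lowers this floor; its exact scheme is not modeled here.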

Returns:

A random array with boolean dtype and shape given by shape if shape is not None, or else p.shape.
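The shape rules in the parameter list and the return value can be summarized as a small pure-Python sketch. `broadcast_shapes` and `bernoulli_result_shape` are hypothetical helpers, not JAX APIs. The sketch also assumes a strict reading of "broadcast-compatible": p.shape must broadcast to shape without enlarging it, so that the result shape is exactly shape whenever shape is given.

```python
from typing import Optional, Sequence, Tuple


def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int, ...]:
    """NumPy-style broadcasting of shapes, aligned from the trailing dimension."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(-ndim, 0):  # right-aligned axes, leading to trailing
        dims = {s[axis] for s in shapes if len(s) >= -axis}
        dims.discard(1)  # size-1 dimensions stretch to match the others
        if len(dims) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(dims.pop() if dims else 1)
    return tuple(result)


def bernoulli_result_shape(
    p_shape: Sequence[int], shape: Optional[Sequence[int]] = None
) -> Tuple[int, ...]:
    """Hypothetical helper: result shape is shape if given, else p.shape."""
    if shape is None:
        return tuple(p_shape)
    shape = tuple(shape)
    # Assumption: p.shape must broadcast to shape without enlarging it.
    if broadcast_shapes(shape, p_shape) != shape:
        raise ValueError(f"p.shape {tuple(p_shape)} does not broadcast to {shape}")
    return shape


print(bernoulli_result_shape((), (8, 128)))      # (8, 128): scalar p
print(bernoulli_result_shape((128,), (8, 128)))  # (8, 128): p broadcast over rows
print(bernoulli_result_shape((4, 1)))            # (4, 1): shape=None uses p.shape
```

A call such as bernoulli_result_shape((8, 128), (128,)) raises ValueError under this reading, because broadcasting would enlarge the requested shape.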