jax.random.gumbel



jax.random.gumbel(key, shape=(), dtype=None, mode=None, *, out_sharding=None)

Sample standard Gumbel random values with the given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = e^{-(x + e^{-x})}\]
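As a sanity check of this density, the sketch below draws samples and compares their empirical mean and variance with the known moments of the standard Gumbel distribution: the mean is the Euler-Mascheroni constant (about 0.5772) and the variance is π²/6. It also draws the same distribution by inverse-transform sampling, using the CDF F(x) = exp(-e^{-x}) obtained by integrating the density above. The sample size and seed are arbitrary choices for illustration.

```python
import math

import jax
import jax.numpy as jnp

key = jax.random.key(0)
key_gumbel, key_uniform = jax.random.split(key)

n = 200_000
# Direct sampling with jax.random.gumbel.
samples = jax.random.gumbel(key_gumbel, shape=(n,))

# Inverse-transform sampling: F(x) = exp(-exp(-x)) inverts to
# x = -log(-log(u)) for u ~ Uniform(0, 1).
# minval=tiny keeps u strictly positive, so log(u) is finite.
u = jax.random.uniform(
    key_uniform, shape=(n,), minval=jnp.finfo(jnp.float32).tiny, maxval=1.0
)
inverse_cdf_samples = -jnp.log(-jnp.log(u))

euler_gamma = 0.5772156649      # mean of the standard Gumbel distribution
gumbel_var = math.pi ** 2 / 6   # variance of the standard Gumbel distribution

for name, x in [("gumbel", samples), ("inverse CDF", inverse_cdf_samples)]:
    print(
        f"{name}: mean={float(x.mean()):.3f} (expected {euler_gamma:.3f}), "
        f"var={float(x.var()):.3f} (expected {gumbel_var:.3f})"
    )
```

Both estimates should land close to the expected values; any remaining gap is ordinary sampling noise.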

Parameters:
  • key (ArrayLike) – a PRNG key, for example one created with jax.random.key or jax.random.PRNGKey.

  • shape (Shape) – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype (DTypeLikeFloat | None) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • mode (str | None) – optional, "high" or "low", controlling how many random bits are used per sample. The default is determined by the use_high_dynamic_range_gumbel config, which defaults to "low". For float32 samples, mode="low" gives a uniform resolution such that the largest possible Gumbel logit is about 16; mode="high" raises this to about 32, at approximately double the computational cost.

  • out_sharding (NamedSharding | P | None) – optional, how the output array should be sharded across devices in multi-device computation: a NamedSharding, a PartitionSpec (P), or None (the default). When given, the output is sharded according to that specification. This is primarily used in explicit sharding mode; see the explicit sharding tutorial for details.

Returns:

A random array with the specified shape and dtype.

Return type:

Array
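Because the result is an ordinary jax.Array, gumbel composes with JAX transformations. The sketch below is one way to do this, not a prescribed pattern: the helper name draw is hypothetical, shape is marked static under jax.jit because it fixes the output size, and jax.vmap maps over a batch of split keys to produce one independent row per key.

```python
import jax


# Hypothetical helper: shape must be static under jit because it
# determines the size of the output array.
def draw(key, shape):
    return jax.random.gumbel(key, shape=shape)


draw_jit = jax.jit(draw, static_argnames="shape")

keys = jax.random.split(jax.random.key(0), 8)          # 8 independent keys
batch = jax.vmap(lambda k: draw_jit(k, shape=(5,)))(keys)

print(isinstance(batch, jax.Array), batch.shape)  # True (8, 5)
```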