jax.random.categorical

jax.random.categorical(key, logits, axis=-1, shape=None, replace=True, mode=None, *, out_sharding=None)

Sample random values from categorical distributions.

Sampling with replacement uses the Gumbel-max trick; sampling without replacement uses the Gumbel-top-k trick. See [1] for details.
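
Both tricks can be sketched in a few lines. This is a minimal illustration, not JAX's internal implementation. The helper names gumbel_max_sample and gumbel_top_k_sample are hypothetical, and the top-k helper assumes the categories lie on the last axis because jax.lax.top_k operates there. The probabilities are arbitrary example values.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper (not part of JAX): Gumbel-max sampling with replacement.
def gumbel_max_sample(key, logits, axis=-1):
    # Perturb every logit with i.i.d. standard Gumbel noise; the argmax of the
    # perturbed logits is distributed as softmax(logits, axis).
    noise = jax.random.gumbel(key, logits.shape, dtype=logits.dtype)
    return jnp.argmax(logits + noise, axis=axis)

# Hypothetical helper (not part of JAX): Gumbel-top-k sampling without replacement.
def gumbel_top_k_sample(key, logits, k):
    # Perturb once, then keep the indices of the k largest perturbed logits.
    # The indices are distinct, i.e. k categories drawn without replacement.
    # Assumption: categories lie on the last axis, where lax.top_k operates.
    noise = jax.random.gumbel(key, logits.shape, dtype=logits.dtype)
    _, indices = jax.lax.top_k(logits + noise, k)
    return indices

key_a, key_b = jax.random.split(jax.random.key(0))
probs = jnp.array([0.1, 0.2, 0.7])
logits = jnp.log(probs)

# 10,000 independent draws with replacement: one row of logits per draw.
draws = gumbel_max_sample(key_a, jnp.broadcast_to(logits, (10_000, 3)))
freqs = jnp.bincount(draws, length=3) / draws.shape[0]  # close to probs

# Two distinct categories drawn without replacement.
pair = gumbel_top_k_sample(key_b, logits, k=2)
```

Sampling without replacement needs only a single Gumbel perturbation: taking the top k perturbed logits, rather than repeating the argmax k times, is what keeps the drawn indices distinct.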

Parameters:
  • key (ArrayLike) – A PRNG key used as the random key.

  • logits (RealArray) – Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.

  • axis (int) – Axis along which logits belong to the same categorical distribution.

  • shape (Shape | None) – Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with np.delete(logits.shape, axis). The default (None) produces a result shape equal to np.delete(logits.shape, axis).

  • replace (bool) – If True (default), perform sampling with replacement. If False, perform sampling without replacement.

  • mode (str | None) – Optional, "high" or "low", selecting how many bits the Gumbel sampler uses. The default is determined by the use_high_dynamic_range_gumbel config option, which defaults to "low". With mode="low" in float32, sampling is biased for events with probability below about 1E-7; with mode="high", this limit drops to about 1E-14. mode="high" roughly doubles the cost of sampling.

  • out_sharding (NamedSharding | P | None) – Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with int dtype and shape given by shape if shape is not None, or else np.delete(logits.shape, axis).
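
This minimal sketch checks the default result shape with axis=0, so that each column of a (4, 3) logits array is a separate distribution over 4 categories. The uniform logits (all zeros) are an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.key(0)

# Each column is one distribution over 4 categories (axis=0 is the category axis).
logits = jnp.zeros((4, 3))
samples = jax.random.categorical(key, logits, axis=0)

# With shape=None, the result shape is logits.shape with `axis` removed.
expected_shape = tuple(int(d) for d in np.delete(logits.shape, 0))  # (3,)
```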

Return type:

Array

References