jax.random.permutation

jax.random.permutation(key, x, axis=0, independent=False, *, out_sharding=None)

Returns a randomly permuted array or range.

Parameters:
  • key (ArrayLike) – a PRNG key used as the random key.

  • x (int | ArrayLike) – int or array. If x is an integer, randomly shuffle np.arange(x). If x is an array, randomly shuffle its elements along the given axis.

  • axis (int) – int, optional. The axis along which x is shuffled. Default is 0.

  • independent (bool) – bool, optional. If set to True, each individual vector along the given axis is shuffled independently. Default is False.

  • out_sharding (NamedSharding | P | None) – Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.
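
The difference between the default behavior and independent=True is easiest to see on a small 2-D array. The following is a minimal sketch, not part of the official reference; the seed, array shape, and variable names are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)          # arbitrary seed, chosen for illustration
x = jnp.arange(12).reshape(3, 4)     # rows: [0..3], [4..7], [8..11]

# independent=False (default): a single permutation is applied along axis 0,
# so the rows are reordered but each row keeps its contents.
rows_shuffled = jax.random.permutation(key, x, axis=0)

# independent=True with axis=1: every row is shuffled on its own,
# so elements move within a row but never between rows.
per_row = jax.random.permutation(key, x, axis=1, independent=True)

print(rows_shuffled)
print(per_row)
```

In the first call the rows move as whole units; in the second, each row holds the same elements as before in a different order.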

Returns:

A shuffled version of x if x is an array, or of np.arange(x) if x is an integer.

Return type:

Array
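
As a rough illustration of the two input forms, the sketch below passes first an integer and then a 1-D array. The seed and the sample values are assumptions made for the example; the exact output order depends on the key.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)                 # arbitrary seed for the example

# Integer input: the result is a shuffled version of the range 0..4.
perm = jax.random.permutation(key, 5)

# Array input: the result holds the same elements in a shuffled order.
values = jnp.array([10, 20, 30, 40, 50])     # hypothetical sample data
shuffled = jax.random.permutation(key, values)

# The function is deterministic in the key: the same key gives the same result.
perm_again = jax.random.permutation(key, 5)

print(perm, shuffled)
```

To draw a different permutation, derive a fresh key first, for example with jax.random.split(key).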