jax.random.randint


jax.random.randint(key, shape, minval, maxval, dtype=None, *, out_sharding=None)

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • key (ArrayLike) – a PRNG key used as the random key.

  • shape (Shape) – a tuple of nonnegative integers representing the shape.

  • minval (IntegerArray) – int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.

  • maxval (IntegerArray) – int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.

  • dtype (DTypeLikeInt | None) – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).

  • out_sharding (NamedSharding | P | None) – Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified shape and dtype.

Return type:

Array
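As a minimal usage sketch of the signature above (the key seed, shapes, and bounds are arbitrary illustrative choices, not values from this page):

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)  # illustrative seed

# A 2x3 array of integers drawn from [0, 10); maxval is exclusive.
x = jax.random.randint(key, shape=(2, 3), minval=0, maxval=10)

# minval and maxval may be arrays that broadcast against shape:
# here each of the 3 columns gets its own range.
lo = jnp.array([0, 10, 100])
hi = jnp.array([5, 20, 200])
y = jax.random.randint(key, (2, 3), lo, hi)

# An explicit integer dtype can be requested for the result.
z = jax.random.randint(key, (4,), 0, 100, dtype=jnp.int16)
```

In real code, a fresh key would normally be derived with jax.random.split for each call rather than reusing one key as done here for brevity.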

Note

randint() uses a modulus-based computation that is known to produce slightly biased values in some cases. The magnitude of the bias scales as (maxval - minval) * ((2 ** nbits) % (maxval - minval)) / 2 ** nbits. In words: the bias is zero when (maxval - minval) is a power of 2, and otherwise it is small whenever (maxval - minval) is small compared to the range of the sampled type.
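To get a feel for the expression in the note, it can be evaluated directly in plain Python. This is only a sketch: the helper name randint_bias_scale is hypothetical, and nbits=32 is an assumption about the width at which values are sampled.

```python
def randint_bias_scale(minval, maxval, nbits=32):
    """Evaluate the bias-scale expression from the note (hypothetical helper)."""
    span = maxval - minval
    return span * ((2 ** nbits) % span) / 2 ** nbits

# A power-of-two span divides 2 ** nbits exactly, so the expression is zero.
power_of_two = randint_bias_scale(0, 8)

# A small span that is not a power of two gives a tiny, nonzero value.
small_span = randint_bias_scale(0, 3)

# A span comparable to the 32-bit range gives a much larger value.
large_span = randint_bias_scale(0, 3 * 2 ** 30)
```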

To reduce this bias, 8-bit and 16-bit values will always be sampled at 32-bit and then cast to the requested type. If you find yourself sampling values for which this bias may be problematic, a possible alternative is to sample via uniform:

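import jax

def randint_via_uniform(key, shape, minval, maxval, dtype):
  # Shift the float range down by 0.5 so that rounding to nearest targets minval, ..., maxval - 1.
  u = jax.random.uniform(key, shape, minval=minval - 0.5, maxval=maxval - 0.5)
  return u.round().astype(dtype)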

But keep in mind this method has its own biases due to floating point rounding errors, and in particular there may be some integers in the range [minval, maxval) that are impossible to produce with this approach.
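One way to see why some integers become unreachable, assuming the uniform samples are float32 (the default when jax_enable_x64 is off): float32 carries a 24-bit significand, so not every integer above 2 ** 24 has an exact representation. A small illustration:

```python
import jax.numpy as jnp

n = 2 ** 24 + 1             # smallest positive integer with no exact float32 form
as_float = jnp.float32(n)   # converted to the nearest representable float32
roundtrip = int(as_float)   # back to a Python int

# roundtrip differs from n, so a float32 sample rounded to an integer
# can never equal n.
unreachable = roundtrip != n
```

For ranges that stay well below 2 ** 24 in magnitude this particular effect does not arise, though the other rounding biases mentioned above may still apply.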