jax.random.randint#

jax.random.randint(key, shape, minval, maxval, dtype=None, *, out_sharding=None)[source]#

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • key (ArrayLike) – a PRNG key used as the random key.

  • shape (Shape) – a tuple of nonnegative integers representing the shape.

  • minval (IntegerArray) – int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.

  • maxval (IntegerArray) – int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.

  • dtype (DTypeLikeInt | None) – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).

Returns:

A random array with the specified shape and dtype.

Return type:

Array

Note

randint() uses a modulus-based computation that is known to produce slightly biased values in some cases. The magnitude of the bias scales as (maxval - minval) * ((2 ** nbits ) % (maxval - minval)) / 2 ** nbits: in words, the bias goes to zero when (maxval - minval) is a power of 2, and otherwise the bias will be small whenever (maxval - minval) is small compared to the range of the sampled type.

To reduce this bias, 8-bit and 16-bit values will always be sampled at 32-bit and then cast to the requested type. If you find yourself sampling values for which this bias may be problematic, a possible alternative is to sample via uniform:

def randint_via_uniform(key, shape, minval, maxval, dtype):
  u = jax.random.uniform(key, shape, minval=minval - 0.5, maxval=maxval - 0.5)
  return u.round().astype(dtype)

But keep in mind this method has its own biases due to floating point rounding errors, and in particular there may be some integers in the range [minval, maxval) that are impossible to produce with this approach.