jax.numpy.array

jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0, *, device=None, out_sharding=None)

Convert an object to a JAX array.

JAX implementation of numpy.array().

Parameters:
  • object (Any) – an object that is convertible to an array. This includes JAX arrays, NumPy arrays, Python scalars, Python collections like lists and tuples, objects with a __jax_array__ method, and objects supporting the Python buffer protocol.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – optionally specify the dtype of the output array. If not specified, it is inferred from the input.

  • copy (bool) – specify whether to force a copy of the input. Default: True.

  • order (str | None) – not implemented in JAX.

  • ndmin (int) – integer specifying the minimum number of dimensions in the output array.

  • device (Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.

  • out_sharding (NamedSharding | PartitionSpec | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.
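The dtype and ndmin parameters described above can be exercised as in the following minimal sketch. It assumes the default JAX configuration; the variable names are illustrative only.

```python
import jax.numpy as jnp

# dtype: request a specific output dtype instead of inferring it from the input.
x = jnp.array([1, 2, 3], dtype=jnp.float32)
print(x.dtype)  # float32

# ndmin: leading axes of length 1 are prepended until the array has at least ndmin dims.
y = jnp.array([1, 2, 3], ndmin=3)
print(y.shape)  # (1, 1, 3)

# If the input already has at least ndmin dimensions, its shape is left unchanged.
w = jnp.array([[1, 2]], ndmin=1)
print(w.shape)  # (1, 2)
```

As in NumPy, ndmin only ever adds axes at the front; it never removes existing ones.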

Returns:

A JAX array constructed from the input.

Return type:

Array

Examples

Constructing JAX arrays from Python scalars:

>>> jnp.array(True)
Array(True, dtype=bool)
>>> jnp.array(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.array(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.array(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)
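The weak_type=True in the outputs above marks arrays created from Python scalars without an explicit dtype. Under JAX's type promotion rules, a weakly typed value defers to the dtype of the array it is combined with, while passing dtype explicitly produces an ordinary (strongly typed) array. A minimal sketch:

```python
import jax.numpy as jnp

small = jnp.array([1, 2, 3], dtype=jnp.int8)

weak = jnp.array(42)                     # from a Python int, no dtype -> weakly typed
strong = jnp.array(42, dtype=jnp.int32)  # explicit dtype -> not weakly typed

print((small + weak).dtype)    # int8: the weak value adopts the other operand's dtype
print((small + strong).dtype)  # int32: regular promotion between int8 and int32
```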

Constructing JAX arrays from Python collections:

>>> jnp.array([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.array(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)
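When the inferred dtype comes from a Python collection, mixed element types are promoted to a common dtype, and a list of equally shaped JAX arrays is stacked along a new leading axis. A short sketch, assuming the default configuration:

```python
import jax.numpy as jnp

mixed = jnp.array([1, 2.5, 3])  # ints and a float -> floating-point result
print(mixed.dtype)  # float32 under the default (32-bit) configuration

stacked = jnp.array([jnp.ones(3), jnp.zeros(3)])  # list of JAX arrays -> 2D array
print(stacked.shape)  # (2, 3)
```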

Constructing JAX arrays from NumPy arrays:

>>> jnp.array(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
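Note that np.linspace returns float64 values, while the result above is float32: by default JAX canonicalizes 64-bit dtypes to their 32-bit counterparts unless 64-bit mode (jax_enable_x64) is enabled. The sketch below compares against jax.dtypes.canonicalize_dtype so that it holds under either setting:

```python
import jax
import jax.numpy as jnp
import numpy as np

a = np.linspace(0, 2, 5)  # NumPy produces float64 by default
print(a.dtype)            # float64

b = jnp.array(a)
# float32 with the default configuration; float64 if jax_enable_x64 is set.
print(b.dtype)

# canonicalize_dtype reports the dtype JAX uses for a.dtype under the current config.
print(b.dtype == jax.dtypes.canonicalize_dtype(a.dtype))  # True
```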

Constructing a JAX array via the Python buffer interface, using Python’s built-in array module:

>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.array(pybuffer)
Array([2, 3, 5, 7], dtype=int32)
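The object parameter also accepts objects that define a __jax_array__ method. In this minimal sketch, Meters is a hypothetical wrapper class invented for illustration; it assumes, as the parameter description above states, that jnp.array calls __jax_array__ and converts the array it returns.

```python
import jax.numpy as jnp


class Meters:
    """Hypothetical wrapper type, used only to illustrate __jax_array__."""

    def __init__(self, values):
        self.values = list(values)

    def __jax_array__(self):
        # Return the JAX array that should stand in for this object.
        return jnp.asarray(self.values, dtype=jnp.float32)


distances = jnp.array(Meters([1.0, 2.5, 4.0]))
print(distances.shape, distances.dtype)  # (3,) float32
```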