jax.experimental.pallas.empty

jax.experimental.pallas.empty(shape, dtype, *, out_sharding=None)

Create an empty array whose values may be uninitialized.

Whether the values are initialized, and to what, is backend dependent.

Parameters:
  • shape – int or sequence of ints specifying the shape of the created array.

  • dtype – dtype for the created array.

  • out_sharding – (optional) PartitionSpec or NamedSharding describing the sharding of the created array (see the explicit sharding documentation for more details).

Returns:

An array of the specified shape, dtype, and sharding whose values may be uninitialized.

Examples

>>> from jax.experimental import pallas as pl
>>> pl.empty(3, jnp.float32)
Array([-5.7326739e+29, -7.7323739e+29, -3.14159256e-29], dtype=float32)
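Because the initial contents are unspecified, code that uses an empty allocation should write every element before reading any of them. The sketch below illustrates that pattern under stated assumptions: it takes any allocator with the `(shape, dtype)` calling convention documented above, and the helper name `iota_into` is hypothetical. It is not part of Pallas. In the usage line, `jnp.empty` (jax.numpy's allocator, a different function) is passed in only so the sketch runs on CPU without a Pallas-capable backend.

```python
import jax.numpy as jnp

def iota_into(alloc, n):
    """Hypothetical helper: fill a fresh length-n buffer with 0..n-1.

    `alloc` is assumed to follow the (shape, dtype) calling convention
    described above; its initial contents are treated as garbage.
    """
    buf = alloc(n, jnp.float32)  # contents unspecified: never read before writing
    # Overwrite every element; the result no longer depends on the old contents.
    return buf.at[:].set(jnp.arange(n, dtype=jnp.float32))

# Stand-in allocator so the sketch runs on CPU.
out = iota_into(jnp.empty, 3)
```

The design choice here is to write the whole buffer in one functional update, so no code path ever observes the uninitialized values.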