jax.ShapeDtypeStruct

class jax.ShapeDtypeStruct(shape, dtype, *, sharding=None, weak_type=False, manual_axis_type=None, is_ref=False)

A container for the shape, dtype, and other static attributes of an array.

ShapeDtypeStruct is often used in conjunction with jax.eval_shape().
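As a minimal sketch (the function `matmul_t` is a hypothetical example, not part of JAX), the following passes a ShapeDtypeStruct to jax.eval_shape() to learn the output shape and dtype of a function without allocating its input or running the computation:

```python
import jax
import jax.numpy as jnp

def matmul_t(x):
    # Hypothetical example function: multiply x by its own transpose.
    return jnp.dot(x, x.T)

# Describe a (4, 8) float32 input abstractly; no array memory is allocated.
spec = jax.ShapeDtypeStruct((4, 8), jnp.float32)

# eval_shape traces matmul_t abstractly and returns a ShapeDtypeStruct
# describing the result instead of a concrete array.
out = jax.eval_shape(matmul_t, spec)
print(out.shape, out.dtype)  # (4, 4) float32
```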

Parameters:
  • shape (Any) – a sequence of integers representing an array shape

  • dtype (Any) – a dtype-like object

  • sharding – (optional) a jax.Sharding object

  • weak_type (Any) – (optional) whether the dtype is weakly typed, which affects type promotion; defaults to False

  • manual_axis_type (Any)

  • is_ref (Any)

__init__(shape, dtype, *, sharding=None, weak_type=False, manual_axis_type=None, is_ref=False)

Methods

__init__(shape, dtype, *[, sharding, ...])

update(**kwargs)

Attributes

shape

dtype

weak_type

manual_axis_type

is_ref

format

ndim

sharding

size
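A short sketch of the derived attributes: `ndim` and `size` follow from `shape` (the number of dimensions and the total element count). The pytree `params` is a hypothetical example; mapping each array in it to a ShapeDtypeStruct is one common way to describe a whole structure of arrays abstractly:

```python
import jax
import jax.numpy as jnp

spec = jax.ShapeDtypeStruct((2, 3, 4), jnp.float32)
print(spec.ndim, spec.size)  # 3 24

# Hypothetical parameter pytree; replace every array with its abstract description.
params = {"w": jnp.zeros((16, 32)), "b": jnp.zeros((32,))}
specs = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), params)
print(specs["w"].shape, specs["b"].shape)  # (16, 32) (32,)
```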