jax.ShapeDtypeStruct
- class jax.ShapeDtypeStruct(shape, dtype, *, sharding=None, weak_type=False, manual_axis_type=None, is_ref=False)
A container for the shape, dtype, and other static attributes of an array.
ShapeDtypeStruct is often used in conjunction with jax.eval_shape().
- Parameters:
shape (Any) – a sequence of integers representing an array shape
dtype (Any) – a dtype-like object
sharding – (optional) a jax.Sharding object
weak_type (Any)
manual_axis_type (Any)
is_ref (Any)
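As a minimal sketch of the pairing with jax.eval_shape() mentioned above: the function `predict` and the shapes chosen here are hypothetical and serve only as illustration. The specs stand in for real arrays, so the output shape and dtype are inferred without allocating any array data.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    # Hypothetical function used only for illustration.
    return jnp.tanh(x @ w)


# Describe the inputs by shape and dtype only; no array data is allocated.
w_spec = jax.ShapeDtypeStruct((128, 10), jnp.float32)
x_spec = jax.ShapeDtypeStruct((32, 128), jnp.float32)

# eval_shape traces predict abstractly and returns a ShapeDtypeStruct
# describing the output.
out_spec = jax.eval_shape(predict, w_spec, x_spec)
print(out_spec.shape, out_spec.dtype)  # (32, 10) float32
```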
- __init__(shape, dtype, *, sharding=None, weak_type=False, manual_axis_type=None, is_ref=False)
Methods
__init__(shape, dtype, *[, sharding, ...])
update(**kwargs)
Attributes
shape
dtype
weak_type
manual_axis_type
is_ref
format
ndim
sharding
size