jax.ref.AbstractRef

jax.ref.AbstractRef

class jax.ref.AbstractRef(inner_aval, memory_space=None, kind=None) [source]

Abstract mutable array reference.

Refer to the JAX Ref guide for more information.

Parameters:
  • inner_aval (core.AbstractValue)

  • memory_space (Any)

  • kind (Any)

__init__(inner_aval, memory_space=None, kind=None) [source]
Parameters:
  • inner_aval (core.AbstractValue)

  • memory_space (Any)

  • kind (Any)

Methods

__init__(inner_aval[, memory_space, kind])

at_least_vspace()

dec_rank(size, spec)

inc_rank(size, spec)

leading_axis_spec()

lo_ty()

lo_ty_qdd(qdd)

lower_val(ref)

lower_val2(hi_val)

normalize()

raise_val(*vals)

raise_val2(lo_vals_ft)

shard(mesh, manual_axes, check_vma, spec)

str_short([short_dtypes, mesh_axis_types])

strip_weak_type()

to_ct_aval()

to_tangent_aval()

unshard(mesh, check_vma, spec)

update([inner_aval, memory_space, kind])

update_manual_axis_type(mat)

update_weak_type(weak_type)

vspace_add(x, y)

Attributes

inner_aval

memory_space

kind

T

addupdate

at

bitcast

dtype

get

has_qdd

is_high

manual_axis_type

mat

ndim

reshape

set

shape

sharding

size

swap

transpose

weak_type