jax.extend.core.AbstractToken

jax.extend.core.AbstractToken#

class jax.extend.core.AbstractToken[source]#
__init__()#

Methods

__init__()

at_least_vspace()

dec_rank(size, spec)

inc_rank(size, spec)

leading_axis_spec()

lo_ty()

lo_ty_qdd(qdd)

lower_val2(hi_val)

normalize()

raise_val2(lo_vals_ft)

shard(mesh, manual_axes, check_vma, spec)

str_short([short_dtypes, mesh_axis_types])

strip_weak_type()

to_ct_aval()

to_tangent_aval()

unshard(mesh, check_vma, spec)

update(**kwargs)

update_manual_axis_type(mat)

update_weak_type(weak_type)

vspace_add(x, y)

Attributes

has_qdd

is_high