jax.extend.core module

AbstractToken()

CallPrimitive(name)

ClosedJaxpr(jaxpr, consts)
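
A ClosedJaxpr bundles a Jaxpr with the concrete values for its constvars. The sketch below assumes jax.make_jaxpr returns this same ClosedJaxpr class. It traces a hypothetical function f and takes the pair apart:

```python
import jax
import jax.numpy as jnp
from jax.extend.core import ClosedJaxpr, Jaxpr


def f(x):
    # Hypothetical example function.
    return jnp.sin(x) * 2.0


# make_jaxpr traces f for a scalar float input and returns a ClosedJaxpr.
closed = jax.make_jaxpr(f)(1.0)

# The underlying Jaxpr holds the variables and equations.
jaxpr = closed.jaxpr
# consts supplies one value per entry of jaxpr.constvars, in the same order.
consts = closed.consts

print(type(closed).__name__, type(jaxpr).__name__)
print(jaxpr)
```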

DebugInfo(traced_for, func_src_info, ...)

Debugging info about a func, its arguments, and results.

DropVar(aval)

Effect()

A generic side-effect.

Effects

A set of Effect values (a finite, iterable container).

InconclusiveDimensionOperation

Raised when we cannot conclusively compute with symbolic dimensions.

Jaxpr(constvars, invars, outvars, eqns[, ...])

JaxprEqn(invars, outvars, primitive, params, ...)
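
Each JaxprEqn records one primitive application: its primitive, its params, its input atoms (invars) and its output variables (outvars). The following minimal sketch walks the equations of a traced, hypothetical function f. It assumes only the attributes named in the JaxprEqn signature plus Primitive.name:

```python
import jax
import jax.numpy as jnp
from jax.extend.core import JaxprEqn, Literal, Var


def f(x):
    # Hypothetical example function.
    return jnp.sin(x) * 2.0


jaxpr = jax.make_jaxpr(f)(1.0).jaxpr

# Each equation applies one primitive to its inputs and binds its outputs.
for eqn in jaxpr.eqns:
    # Inputs are either a Var (bound earlier) or a Literal (an inlined constant).
    kinds = ["Literal" if isinstance(v, Literal) else "Var" for v in eqn.invars]
    print(eqn.primitive.name, kinds, eqn.params)

prim_names = [eqn.primitive.name for eqn in jaxpr.eqns]
```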

JaxprTypeError

Literal(val, aval)

Primitive(name)
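
A Primitive is applied through Primitive.bind and gets its behavior from registered rules. The minimal sketch below assumes the def_impl and def_abstract_eval registration methods. The double_p primitive and the double wrapper are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.extend.core import Primitive

# Hypothetical primitive, defined only for illustration.
double_p = Primitive("double")


def double(x):
    # bind() dispatches to whatever trace is currently active.
    return double_p.bind(x)


# Eager rule: used when double() is called on concrete arrays.
double_p.def_impl(lambda x: x * 2)

# Abstract rule: used while tracing. The output aval equals the input aval
# (same shape and dtype).
double_p.def_abstract_eval(lambda x_aval: x_aval)

eager = double(jnp.arange(3.0))       # runs the impl rule
traced = jax.make_jaxpr(double)(1.0)  # runs only the abstract rule
print(eager)
print(traced)
```

This sketch registers only an eager rule and an abstract-evaluation rule. As a result, double works eagerly and under jax.make_jaxpr. Using it under jax.jit would also need a lowering rule, and jax.grad would need a differentiation rule. This sketch defines neither.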

Token(buf)

TraceTag()

Var(aval[, initial_qdd, final_qdd])

array_types

A set of Python types that JAX treats as array types.

call_impl(f, *args, **params)

check_jaxpr(jaxpr)

Checks well-formedness of a jaxpr.

concrete_or_error(force, val[, context])

Like force(val), but gives the context in the error message.

find_top_trace(_)

gensym()

get_opaque_trace_state([convention])

jaxpr_as_fun
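
jaxpr_as_fun turns a ClosedJaxpr back into an ordinary Python callable. The sketch below assumes the callable takes the jaxpr inputs positionally. It also assumes the callable returns a flat sequence of outputs, one per outvar. The function f is hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.extend.core import jaxpr_as_fun


def f(x, y):
    # Hypothetical function with two outputs.
    return jnp.sin(x) * y, x + y


closed = jax.make_jaxpr(f)(1.0, 2.0)

# Wrap the ClosedJaxpr as a callable that evaluates it on concrete inputs.
g = jaxpr_as_fun(closed)
outs = g(jnp.float32(1.0), jnp.float32(2.0))
print(outs)
```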

jaxprs_in_params(params)

mapped_aval(size, axis, aval)

new_jaxpr_eqn(invars, outvars, primitive, ...)

no_effects

An empty frozenset: the Effects value of an effect-free computation.

nonempty_axis_env_DO_NOT_USE()

primal_dtype_to_tangent_dtype(primal_dtype)

primitives

set_current_trace

alias of SetCurrentTraceContextManager

subjaxprs(jaxpr)

Generator for all subjaxprs found in the params of jaxpr.eqns.

take_current_trace

alias of TakeCurrentTraceContextManager

unmapped_aval(size, axis, aval[, ...])

unsafe_am_i_under_a_jit_DO_NOT_USE()

unsafe_am_i_under_a_vmap_DO_NOT_USE()

unsafe_get_axis_names_DO_NOT_USE()

valid_jaxtype(x)