jax.extend.lowering.LoweringRuleContext
- class jax.extend.lowering.LoweringRuleContext(module_context, name_stack, traceback, primitive, avals_in, avals_out, tokens_in, tokens_out, const_lowering, axis_size_env=None, dim_var_values=(), jaxpr_eqn_ctx=None, platforms=None)
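Per-rule context information for MLIR lowering. JAX passes an instance of this class as the first argument (conventionally named ctx) to each primitive's MLIR lowering rule.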
- Parameters:
module_context (ModuleContext)
name_stack (source_info_util.NameStack)
traceback (xc.Traceback | None)
primitive (core.Primitive | None)
avals_in (Sequence[core.AbstractValue])
avals_out (Any)
tokens_in (TokenSet)
tokens_out (TokenSet | None)
const_lowering (dict[tuple[int, core.AbstractValue], IrValues])
dim_var_values (Sequence[ir.Value])
jaxpr_eqn_ctx (core.JaxprEqnContext | None)
platforms (Sequence[str] | None)
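A lowering rule normally reads these fields rather than constructing the context itself. The sketch below assumes a recent JAX release and shows a rule that inspects ctx.primitive, ctx.avals_in and ctx.avals_out, then delegates to an ordinary JAX function through the same context. The primitive name double_example and the seen dict are hypothetical. jax.interpreters.mlir.register_lowering and mlir.lower_fun are not documented on this page, and the import fallbacks for older releases are assumptions.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # assumption: older releases

try:
    from jax.extend.lowering import LoweringRuleContext
except ImportError:
    # Assumption: older releases expose the same class only via mlir.
    LoweringRuleContext = mlir.LoweringRuleContext

# Hypothetical primitive, defined only for this illustration.
double_p = Primitive("double_example")
# Abstract evaluation: the output has the same shape and dtype as the input.
double_p.def_abstract_eval(lambda x_aval: x_aval)

seen = {}  # hypothetical record of what the lowering rule observed


def _double_lowering(ctx, x):
    # JAX calls the rule with the LoweringRuleContext first, followed by
    # the MLIR values of the operands.
    seen["is_ctx"] = isinstance(ctx, LoweringRuleContext)
    seen["primitive"] = ctx.primitive.name
    seen["avals_in"] = [(a.shape, str(a.dtype)) for a in ctx.avals_in]
    seen["avals_out"] = [(a.shape, str(a.dtype)) for a in ctx.avals_out]
    # Lower an ordinary JAX function instead of writing MLIR by hand,
    # reusing this context (and therefore its input/output avals).
    return mlir.lower_fun(lambda y: y * 2, multiple_results=False)(ctx, x)


mlir.register_lowering(double_p, _double_lowering)


@jax.jit
def double(x):
    return double_p.bind(x)


out = double(jnp.arange(3.0, dtype=jnp.float32))
print(out)
print(seen)
```

The rule only runs when the primitive is lowered, which is why the example calls it under jax.jit; no eager implementation (def_impl) is registered.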
- __init__(module_context, name_stack, traceback, primitive, avals_in, avals_out, tokens_in, tokens_out, const_lowering, axis_size_env=None, dim_var_values=(), jaxpr_eqn_ctx=None, platforms=None)
- Return type:
None
Methods
- __init__(module_context, name_stack, ...[, ...])
- is_forward_compat(): Returns true if the lowering parameters are in forward compatibility mode.
- replace(**kw)
- set_tokens_out(tokens_out)
Attributes
- axis_size_env
- dim_var_values
- jaxpr_eqn_ctx
- platforms
- module_context
- name_stack
- traceback
- primitive
- avals_in
- avals_out
- tokens_in
- tokens_out
- const_lowering