jax.extend.core.DebugInfo


class jax.extend.core.DebugInfo(traced_for, func_src_info, arg_names, result_paths)

Debugging information about a function, its arguments, and its results.

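Parameters:
traced_for, func_src_info, arg_names, result_paths: stored as the attributes of the same names (see Attributes).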
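The count and index methods listed under Methods carry Python's built-in tuple docstrings, and "Alias for field number 0" on traced_for is the docstring Python generates for named-tuple fields. Together these suggest that DebugInfo is a NamedTuple. The sketch below uses a hypothetical stand-in class, DebugInfoModel, with the four fields in constructor order to show what that implies. It does not import JAX, and the field values are made up for illustration.

```python
from typing import NamedTuple, Optional


class DebugInfoModel(NamedTuple):
    """Hypothetical stand-in mirroring the documented DebugInfo fields."""
    traced_for: str
    func_src_info: str
    arg_names: Optional[tuple[str, ...]]
    result_paths: Optional[tuple[str, ...]]


# Positional order follows the constructor: traced_for, func_src_info,
# arg_names, result_paths. The values here are illustrative only.
info = DebugInfoModel("jit", "f at example.py:3", ("x", "y"), ("result",))

# Named-tuple semantics: field number 0 is traced_for.
print(info[0], info.traced_for)  # jit jit

# count/index are inherited from tuple and search the four field values.
print(info.count("jit"), info.index("f at example.py:3"))  # 1 1

# Named tuples are immutable; _replace returns a modified copy.
renamed = info._replace(func_src_info="g at example.py:3")
print(renamed.func_src_info, "|", info.func_src_info)  # g at example.py:3 | f at example.py:3
```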

Methods

__init__()

assert_arg_names(expected_count)

assert_result_paths(expected_count)

count(value, /)

Return number of occurrences of value.

filter_arg_names(keep)

Keep only the entries of arg_names whose corresponding flag in keep is True.

filter_result_paths(keep)

Keep only the entries of result_paths whose corresponding flag in keep is True.
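Here is a minimal sketch of the boolean-mask filtering that filter_arg_names and filter_result_paths describe. filter_by_mask is a hypothetical helper, not part of JAX. It assumes one keep flag per name and returns a plain tuple of names; this page does not document the real methods' return type.

```python
from typing import Sequence


def filter_by_mask(names: Sequence[str], keep: Sequence[bool]) -> tuple[str, ...]:
    """Keep only the entries of `names` whose matching `keep` flag is True.

    Hypothetical helper illustrating the documented filtering; it is not
    JAX's implementation.
    """
    # Assumption: exactly one flag per name.
    if len(names) != len(keep):
        raise ValueError(f"expected {len(names)} flags, got {len(keep)}")
    return tuple(name for name, k in zip(names, keep) if k)


# E.g. dropping the middle argument from a list of flattened argument paths.
arg_names = ("x", 'dict_arg["a"]', "y")
print(filter_by_mask(arg_names, [True, False, True]))  # ('x', 'y')
```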

index(value[, start, stop])

Return first index of value.

replace_func_name(name)

resolve_result_paths()

Return a DebugInfo with its result paths resolved.
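The description implies that result paths can exist in an unresolved form. One common way to get that is to store a zero-argument callable and evaluate it on demand. The sketch below assumes exactly that, using a hypothetical LazyInfo class. It illustrates the idea only and is not JAX's code.

```python
from typing import Callable, NamedTuple, Union

Paths = tuple[str, ...]


class LazyInfo(NamedTuple):
    """Hypothetical model: result_paths is either a tuple or a thunk producing one."""
    func_src_info: str
    result_paths: Union[Paths, Callable[[], Paths]]

    def resolve_result_paths(self) -> "LazyInfo":
        # Still a thunk: call it and return a copy holding the concrete tuple.
        if callable(self.result_paths):
            return self._replace(result_paths=tuple(self.result_paths()))
        return self  # already resolved; nothing to do


calls = []


def compute_paths() -> Paths:
    calls.append(1)  # record each evaluation of the thunk
    return ("result[0]", "result[1]")


lazy = LazyInfo("f at example.py:3", compute_paths)
resolved = lazy.resolve_result_paths()
print(resolved.result_paths)                        # ('result[0]', 'result[1]')
print(resolved.resolve_result_paths() is resolved)  # True
print(len(calls))                                   # 1
```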

safe_arg_names(expected_count)

Get the arg_names, with a safety check against expected_count.

safe_result_paths(expected_count)

Get the result paths, with a safety check against expected_count.
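This page does not say what the safety check does when the counts disagree. The sketch assumes a length check against expected_count that falls back to empty placeholder names instead of raising. safe_names is a hypothetical function, so consult the JAX source for the actual behavior.

```python
from typing import Optional


def safe_names(names: Optional[tuple[str, ...]], expected_count: int) -> tuple[str, ...]:
    """Hypothetical model of a length-checked getter.

    Assumption: if names are missing or their count does not match
    expected_count, return that many empty placeholders instead of raising.
    """
    if names is not None and len(names) == expected_count:
        return names
    return ("",) * expected_count


print(safe_names(("x", "y"), 2))  # ('x', 'y')
print(safe_names(("x", "y"), 3))  # ('', '', '')
print(safe_names(None, 1))        # ('',)
```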

set_result_paths(ans)

with_unknown_names()

Attributes

arg_names

The paths of the flattened non-static arguments, for example ('x', 'dict_arg["a"]', ...).
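To make "paths of the flattened arguments" concrete, here is a hypothetical flattener over nested dicts, lists, and tuples. It mimics the path syntax of the example above, sorts dict keys for a deterministic order, and ignores static arguments. flat_paths and arg_paths are illustrative names, and JAX's own flattening and path formatting may differ in detail.

```python
from typing import Any


def flat_paths(prefix: str, value: Any) -> list[str]:
    """Hypothetical flattener: one path string per leaf of nested containers.

    Path syntax mimics the documented example 'dict_arg["a"]'.
    """
    if isinstance(value, dict):
        out = []
        for key in sorted(value):  # sort keys for a deterministic order
            out.extend(flat_paths(f'{prefix}["{key}"]', value[key]))
        return out
    if isinstance(value, (list, tuple)):
        out = []
        for i, item in enumerate(value):
            out.extend(flat_paths(f"{prefix}[{i}]", item))
        return out
    return [prefix]  # a leaf, e.g. a single array


def arg_paths(**kwargs: Any) -> tuple[str, ...]:
    """Flatten keyword arguments in the order given (hypothetical helper)."""
    paths = []
    for name, value in kwargs.items():
        paths.extend(flat_paths(name, value))
    return tuple(paths)


print(arg_paths(x=1.0, dict_arg={"a": 2.0, "b": [3.0, 4.0]}))
# ('x', 'dict_arg["a"]', 'dict_arg["b"][0]', 'dict_arg["b"][1]')

# The same scheme with the prefix "result" matches the result_paths examples.
print(tuple(flat_paths("result", (1.0, 2.0))))  # ('result[0]', 'result[1]')
print(tuple(flat_paths("result", 3.0)))         # ('result',)
```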

func_filename

func_lineno

func_name

func_src_info

A description of the function's source, e.g. f'{fun.__name__} at {filename}:{lineno}', or f'{fun.__name__}' if no source location information is available.
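Given the documented func_src_info format, the function name, file name, and line number can be recovered by parsing the string. The sketch assumes, without confirmation from this page, that func_name, func_filename, and func_lineno are derived this way and that replace_func_name swaps only the name part. parse_src_info and replace_name are hypothetical helpers.

```python
import re
from typing import Optional

# Matches "<name> at <filename>:<lineno>"; the name-only form has no " at ".
_SRC_RE = re.compile(r"^(?P<name>\S+) at (?P<file>.+):(?P<line>\d+)$")


def parse_src_info(src: str) -> tuple[str, Optional[str], Optional[int]]:
    """Hypothetical parser returning (func_name, filename, lineno)."""
    m = _SRC_RE.match(src)
    if m is None:
        return src, None, None  # no source location information
    return m["name"], m["file"], int(m["line"])


def replace_name(src: str, name: str) -> str:
    """Swap the function name, keeping any ' at file:line' suffix."""
    _, sep, rest = src.partition(" at ")
    return name + sep + rest


print(parse_src_info("loss at /tmp/train.py:42"))  # ('loss', '/tmp/train.py', 42)
print(parse_src_info("loss"))                     # ('loss', None, None)
print(replace_name("loss at /tmp/train.py:42", "loss_v2"))  # loss_v2 at /tmp/train.py:42
```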

result_paths

The paths to the flattened results, e.g. ('result[0]', 'result[1]') for a function that returns a tuple of arrays, or ('result',) for a function that returns a single array.

traced_for

Alias for field number 0 (the first constructor argument).