jax.debug.print
- jax.debug.print(fmt=None, *args, ordered=False, partitioned=False, skip_format_check=False, _use_logging=False, **kwargs)
Prints values, and works inside staged-out JAX functions (for example, functions transformed with jax.jit).
This function does not work with f-strings, because formatting is delayed until the values are available at run time. An f-string is evaluated while the function is being traced, so it would format an abstract tracer instead of the value. So instead of
jax.debug.print(f"hello {bar}"), write jax.debug.print("hello {bar}", bar=bar).
jax.debug.print supports two ways of being called:
Two-call form (recommended):
    jax.debug.print(ordered=True)("hello {x}", x=42)
Options are passed in the first call. The format string and arguments are passed in the second call. No option arguments are accepted in the second call.
Single-call form (soft deprecated):
    jax.debug.print("hello {x}", x=42, ordered=True)
Mixing the ordered and partitioned options with the print kwargs in a single call is soft deprecated.
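A minimal sketch of the f-string pitfall described above, assuming a CPU backend and the default, non-logging output path (text written to standard output). The function names with_fstring and with_debug_print are hypothetical. Python's built-in print runs only while jax.jit traces the function, so it sees a tracer; jax.debug.print is staged into the compiled computation and formats the concrete value on every call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def with_fstring(x):
    # Hypothetical name. Python's print runs only while jax.jit traces the
    # function, so `x` is an abstract tracer and the f-string formats the
    # tracer, not the value. Later calls with the same shape/dtype reuse the
    # compiled function and print nothing.
    print(f"x = {x}")
    return x * 2


@jax.jit
def with_debug_print(x):
    # Hypothetical name. The print is staged into the compiled computation,
    # so "{x}" is filled in with the concrete value on every call.
    jax.debug.print("x = {x}", x=x)
    return x * 2
```

Calling with_debug_print(jnp.float32(3.0)) prints x = 3.0, while the first call of with_fstring prints a line containing a tracer description. If the printed text must be visible before the program continues (for example, when capturing standard output in a test), jax.effects_barrier() waits for pending side effects such as debug callbacks.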
- Parameters:
fmt (str | None) – A format string, e.g. "hello {x}", that will be used to format input arguments, like str.format. See the Python docs on string formatting and format string syntax.
*args – A list of positional arguments to be formatted, as if passed to fmt.format.
ordered (bool) – A keyword-only argument used to indicate whether or not the staged-out computation will enforce ordering of this jax.debug.print with respect to other ordered jax.debug.print calls.
partitioned (bool) – If True, then print local shards only; this option avoids an all-gather of the operands. If False, print with logical operands; this option requires an all-gather of operands first.
skip_format_check (bool) – If True, the format string is not checked. This is useful when using the function from inside a Pallas TPU kernel, where scalar args will be printed after the format string.
**kwargs – Additional keyword arguments to be formatted, as if passed to fmt.format.
_use_logging (bool)
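A short sketch of how positional *args and **kwargs fill the format string, mirroring str.format. It assumes a CPU backend, default 32-bit types, and output on standard output; the function name step is hypothetical, and only the default options are used.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x, v):
    # Hypothetical name. Positional arguments fill "{}" fields in order,
    # as they would for "x = {}, v = {}".format(x, v).
    jax.debug.print("x = {}, v = {}", x, v)
    total = x + v.sum()
    # Keyword arguments fill named fields; each name in the format string
    # must be passed as a keyword argument.
    jax.debug.print("total = {total}", total=total)
    return total
```

With x = jnp.int32(1) and v = jnp.arange(3, dtype=jnp.int32), step prints x = 1, v = [0 1 2] and total = 4. Because neither call passes ordered=True, the relative order of the two lines is not guaranteed.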
- Return type:
Callable[..., None] | None – a callable when only options are passed (the first call of the two-call form); None otherwise.