jax.debug.print

jax.debug.print(fmt=None, *args, ordered=False, partitioned=False, skip_format_check=False, _use_logging=False, **kwargs)

Prints values, and works inside staged-out JAX functions (for example, functions transformed with jax.jit).

This function does not work with f-strings, because formatting is delayed until the staged-out computation runs. Instead of jax.debug.print(f"hello {bar}"), write jax.debug.print("hello {bar}", bar=bar).
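A minimal sketch of why f-strings fall short, assuming a CPU backend and stdout output. Under jax.jit, a plain Python print with an f-string runs once while tracing and shows an abstract tracer, whereas jax.debug.print fills in the runtime value each time the compiled function executes. The function name `scaled` is illustrative only.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled(bar):
    # Runs only at trace time: `bar` is a tracer here, so this shows
    # a tracer object rather than a concrete number.
    print(f"trace-time value: {bar}")
    # Formatting is deferred until the compiled code runs, so the
    # {bar} placeholder is filled with the concrete runtime value.
    jax.debug.print("hello {bar}", bar=bar)
    return bar * 2.0


out = scaled(jnp.float32(3.0))
jax.effects_barrier()  # wait for pending debug callbacks to finish
```

Calling `scaled` a second time with a new value skips the trace-time print but still runs the jax.debug.print.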

jax.debug.print supports two ways of being called:

  1. Two-call form (recommended): jax.debug.print(ordered=True)("hello {x}", x=42). Options are passed in the first call; the format string and arguments are passed in the second call. The second call accepts no option arguments.

  2. Single-call form (soft deprecated): jax.debug.print("hello {x}", x=42, ordered=True). Mixing the ordered and partitioned options with the print arguments in a single call is soft deprecated.

Parameters:
  • fmt (str | None) – A format string, e.g. "hello {x}", used to format the input arguments as with str.format. Leave it as None to use the two-call form. See the Python docs on string formatting and format string syntax.

  • *args – Positional arguments to be formatted, as if passed to fmt.format.

  • ordered (bool) – Keyword-only. Whether the staged-out computation enforces the ordering of this jax.debug.print with respect to other ordered jax.debug.print calls.

  • partitioned (bool) – If True, print only the local shards, which avoids an all-gather of the operands. If False, print the logical (global) operands, which requires an all-gather of the operands first.

  • skip_format_check (bool) – If True, the format string is not checked. This is useful inside a Pallas TPU kernel, where scalar args are printed after the format string.

  • **kwargs – Additional keyword arguments to be formatted, as if passed to fmt.format.

  • _use_logging (bool)
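As with str.format, positional arguments fill unnamed {} fields and keyword arguments fill named fields, and the two can be mixed in one format string. A minimal sketch, assuming a CPU backend; `summarize` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def summarize(xs):
    lo, hi = jnp.min(xs), jnp.max(xs)
    # `lo` fills the positional {} field (*args);
    # `hi` fills the named {hi} field (**kwargs).
    jax.debug.print("min={} max={hi}", lo, hi=hi)
    return hi - lo


spread = summarize(jnp.array([1.0, 2.0, 3.0]))
jax.effects_barrier()  # make sure the print has been emitted
```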

Return type:

Callable[…, None] | None