jax.experimental.pallas.tpu.run_on_first_core

jax.experimental.pallas.tpu.run_on_first_core(core_axis_name)

Runs a function only on the first core (index 0) along the given core axis. The other cores along that axis skip it.

Parameters:

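core_axis_name (str) – Name of the axis that indexes the cores. Only the core at index 0 along this axis runs the function.

Returns:

A decorator. The function it wraps is executed only on the first core along core_axis_name.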
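The idea behind this helper is to guard work with the core's index along a named axis. That idea can be illustrated on CPU without Pallas or a TPU. The sketch below is a hypothetical analogue, not the Pallas implementation. `first_core_only` is an illustrative name, and `jax.vmap` with `axis_name="core"` stands in for the TPU core axis. Because vmapped code is purely functional, the non-first "cores" return zeros here rather than skipping side effects.

```python
import jax
import jax.numpy as jnp
from jax import lax


def first_core_only(core_axis_name):
  """Hypothetical CPU analogue of a 'run on first core' decorator.

  Returns a decorator. The wrapped function computes f(x) where the index
  along `core_axis_name` is 0 and returns zeros everywhere else.
  """
  def decorator(f):
    def wrapped(x):
      # Index of the current "core" along the named (vmapped) axis.
      is_first = lax.axis_index(core_axis_name) == 0
      # Under vmap, a batched predicate turns this cond into a select.
      return lax.cond(is_first, f, jnp.zeros_like, x)
    return wrapped
  return decorator


@first_core_only("core")
def double(x):
  return 2.0 * x


# Simulate 4 "cores", each holding one row of shape (2,).
data = jnp.arange(8.0).reshape(4, 2)
out = jax.vmap(double, axis_name="core")(data)
# Row 0 becomes [0., 2.]. Rows 1 to 3 are all zeros.
```

In a real Pallas TPU kernel, the guarded body (for example, a copy or store) would not execute on the other cores at all, rather than producing zeros.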