jax.experimental.pallas.tpu.core_barrier

jax.experimental.pallas.tpu.core_barrier(sem, *, core_axis_name)

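Synchronizes all cores along the axis named by core_axis_name. No core on that axis continues past the barrier until every core on the axis has reached it.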

Parameters:

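sem – semaphore used to implement the barrier.
core_axis_name (str) – name of the core axis whose cores are synchronized.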
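This page does not describe how the barrier is built. As a hedged illustration only, the plain-Python sketch below models one common way to build an all-core barrier from counting semaphores: each core signals every other core's semaphore once, then waits for num_cores - 1 signals on its own semaphore. Threads stand in for cores and threading.Semaphore stands in for a hardware semaphore. The names core_barrier_model and run_demo are hypothetical; this is not the Pallas implementation and does not use the Pallas API.

```python
import threading
import time

# Hypothetical model: threads play the role of cores, and each core owns one
# counting semaphore (threading.Semaphore) that the other cores can signal.


def core_barrier_model(sems, core_id, num_cores):
    """Blocks until every one of num_cores simulated cores has called it."""
    if num_cores <= 1:
        return  # A single core has nobody to wait for.
    # Announce arrival: signal each other core's semaphore exactly once.
    for offset in range(1, num_cores):
        sems[(core_id + offset) % num_cores].release()
    # Wait for one arrival signal from each of the other num_cores - 1 cores.
    for _ in range(num_cores - 1):
        sems[core_id].acquire()


def run_demo(num_cores=4):
    """Runs one barrier round and returns the ordered (event, core_id) log."""
    sems = [threading.Semaphore(0) for _ in range(num_cores)]
    log = []
    log_lock = threading.Lock()

    def core_body(core_id):
        # Stagger arrivals so the simulated cores reach the barrier at different times.
        time.sleep(0.01 * core_id)
        with log_lock:
            log.append(("arrive", core_id))
        core_barrier_model(sems, core_id, num_cores)
        with log_lock:
            log.append(("leave", core_id))

    threads = [threading.Thread(target=core_body, args=(i,)) for i in range(num_cores)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


if __name__ == "__main__":
    for event, core_id in run_demo():
        print(event, core_id)
```

Because each simulated core signals its peers before it waits, no core can leave until all of its num_cores - 1 peers have arrived, so every "arrive" entry in the log precedes every "leave" entry.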