jax.experimental.pallas.tpu.sync_copy

jax.experimental.pallas.tpu.sync_copy(src_ref, dst_ref, *, add=False)

Synchronously copies a PyTree of refs to another PyTree of refs. The call returns only after every copy has completed.

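Parameters:

src_ref – PyTree of source refs.

dst_ref – PyTree of destination refs, with the same tree structure as src_ref.

add (bool) – keyword-only; defaults to False.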

Return type:

None