jax.experimental.pallas.tpu.async_remote_copy

jax.experimental.pallas.tpu.async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type=DeviceIdType.MESH)

Issues a remote DMA copying from src_ref to dst_ref.

Parameters:

device_id_type (primitives.DeviceIdType)

Return type:

AsyncCopyDescriptor
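
Example:

The following is a minimal sketch, not taken from this reference page, of how the call might appear inside a Pallas TPU kernel that sends a buffer to the next device along a mesh axis. The kernel name, the axis_name and num_devices arguments, and the tuple form of device_id are assumptions made for illustration. The sketch also assumes that the returned AsyncCopyDescriptor provides a wait() method and that the copy has already been started when the function returns. The kernel only does useful work when launched through pallas_call under shard_map on TPU hardware, with the refs and DMA semaphores supplied by the caller.

```python
from jax import lax
from jax.experimental.pallas import tpu as pltpu


def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem,
                         *, num_devices, axis_name):
    """Hypothetical kernel: send input_ref to the right-hand neighbor.

    num_devices and axis_name are illustrative parameters, expected to be
    bound by the caller (for example with functools.partial).
    """
    # Index of this device along the named mesh axis.
    my_id = lax.axis_index(axis_name)
    # Neighbor to the right, wrapping around at the end of the axis.
    right_neighbor = lax.rem(my_id + 1, num_devices)

    # Issue the remote DMA. device_id_type is left at its default
    # (DeviceIdType.MESH), so device_id is given as mesh coordinates.
    # Passing a one-element tuple for a 1-D mesh is an assumption here.
    copy = pltpu.async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=(right_neighbor,),
    )

    # Assumed: block on the descriptor until the transfer has completed.
    copy.wait()
```

Returning the descriptor lets a kernel overlap other work with the transfer and defer the wait until the destination buffer is actually needed.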