jax.experimental.pallas.tpu.make_async_remote_copy
- jax.experimental.pallas.tpu.make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type=DeviceIdType.MESH)
Creates a description of a remote copy operation.
Copies data from src_ref on the current device to dst_ref on the device specified by device_id. Both the source and the target device should wait on both semaphores through the returned descriptor.
Note that device_id can also refer to the current device.
- Parameters:
src_ref – The source Reference.
dst_ref – The destination Reference.
send_sem – The semaphore on the source device.
recv_sem – The semaphore on the destination device.
device_id (MultiDimDeviceId | IntDeviceId | None) – The device id of the destination device. It can be a tuple, or a dictionary that specifies the communication axis and the destination index.
device_id_type (primitives.DeviceIdType) – The type of the device id. Defaults to DeviceIdType.MESH.
- Returns:
An AsyncCopyDescriptor.
- Return type:
AsyncCopyDescriptor
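As a minimal sketch of how the descriptor might be used, the kernel below sends each device's input buffer to its right neighbor along a mesh axis. The kernel name, the axis name "x", and the device count of 4 are hypothetical choices for illustration. The sketch assumes the kernel is launched on TPU through pallas_call inside a mapped computation that defines the named axis, with two DMA semaphores supplied as scratch arguments; that launch code is omitted here.

```python
import functools

from jax import lax
from jax.experimental.pallas import tpu as pltpu


def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem,
                         *, axis_name, num_devices):
    """Copies this device's input buffer to its right neighbor's output buffer."""
    # Index of the current device along the (hypothetical) mesh axis.
    my_id = lax.axis_index(axis_name)
    # Wrap around so the last device sends to the first one.
    right_neighbor = lax.rem(my_id + 1, num_devices)

    # Build the descriptor; no data moves until start() is called.
    # device_id is given as a tuple and interpreted with the default
    # device_id_type (DeviceIdType.MESH).
    remote_copy = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=(right_neighbor,),
    )
    remote_copy.start()
    # Every device runs this same kernel, so each one is both a sender and
    # a receiver and waits through the descriptor.
    remote_copy.wait()


# Bind the illustrative axis name and device count before handing the
# kernel to pallas_call.
kernel = functools.partial(right_permute_kernel, axis_name="x", num_devices=4)
```

Because creating the descriptor only describes the transfer, other work could in principle be placed between start() and wait() to overlap communication with computation.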