jax.experimental.pallas.tpu.make_async_remote_copy

jax.experimental.pallas.tpu.make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type=DeviceIdType.MESH)

Creates a description of a remote copy operation.

Copies data from src_ref on the current device to dst_ref on the device specified by device_id. Both semaphores should be waited on through the returned descriptor, on both the source and the target devices.

Note that device_id can also refer to the current device.
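The sketch below is modeled on the right-permute example from the JAX "Distributed Computing in Pallas for TPUs" tutorial. It shifts each device's shard to its right neighbor in a 1-D ring. The mesh axis name "x", the helper names right_permute_kernel and run_right_permute, and the (8, 128) float32 shard shape are illustrative assumptions. Module paths (notably shard_map and pl.ANY) differ between JAX releases, and the driver only does real work on a host with at least two TPU devices. As we understand it, the descriptor's wait() waits on both send_sem and recv_sem. Every device calls it because, in a ring, each device is both a source and a target.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem, *, num_devices):
    # Index of this device along the (assumed) 1-D mesh axis "x".
    my_id = lax.axis_index("x")
    right_neighbor = lax.rem(my_id + 1, num_devices)
    copy = pltpu.make_async_remote_copy(
        src_ref=input_ref,            # local shard to send
        dst_ref=output_ref,           # written on the neighbor device
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=(right_neighbor,),  # mesh coordinates of the destination
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    copy.start()
    # Each device both sends and receives, so each one waits on the descriptor.
    copy.wait()


def run_right_permute():
    devices = jax.devices()
    if devices[0].platform != "tpu" or len(devices) < 2:
        return None  # Remote DMAs need two or more TPU devices.
    try:
        from jax import shard_map  # newer JAX releases
        smap = functools.partial(shard_map, check_vma=False)
    except ImportError:
        from jax.experimental.shard_map import shard_map  # older releases
        smap = functools.partial(shard_map, check_rep=False)

    n = len(devices)
    mesh = Mesh(np.array(devices), ("x",))
    spec = P(None, "x")  # shard the last dimension across the ring
    x = jax.device_put(
        jax.random.uniform(jax.random.key(0), (8, 128 * n)),
        NamedSharding(mesh, spec),
    )
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[pl.BlockSpec(memory_space=pl.ANY)],  # leave data in place (usually HBM)
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
        scratch_shapes=[pltpu.SemaphoreType.DMA] * 2,  # send_sem, recv_sem
    )
    right_permute = pl.pallas_call(
        functools.partial(right_permute_kernel, num_devices=n),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),  # per-shard shape
        grid_spec=grid_spec,
    )
    pallas_out = jax.jit(
        smap(right_permute, mesh=mesh, in_specs=spec, out_specs=spec)
    )(x)

    # Reference result: the same ring shift expressed with lax.ppermute.
    perm = [(i, (i + 1) % n) for i in range(n)]
    xla_out = jax.jit(
        smap(lambda s: lax.ppermute(s, "x", perm), mesh=mesh, in_specs=spec, out_specs=spec)
    )(x)
    return pallas_out, xla_out
```

Comparing the kernel's output against lax.ppermute is a convenient correctness check. In a real kernel the wait may be deferred so that computation can overlap with the transfer.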

Parameters:
  • src_ref – The source reference.

  • dst_ref – The destination reference.

  • send_sem – The semaphore on the source device.

  • recv_sem – The semaphore on the destination device.

  • device_id (MultiDimDeviceId | IntDeviceId | None) – The ID of the destination device. It can be a tuple of mesh coordinates, or a dictionary mapping a communication axis name to the destination index along that axis.

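  • device_id_type (primitives.DeviceIdType) – How device_id is interpreted: DeviceIdType.MESH (the default) for mesh coordinates, or DeviceIdType.LOGICAL for a logical device ID.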

Returns:

An AsyncCopyDescriptor.

Return type:

AsyncCopyDescriptor