jax.experimental.pallas.tpu.make_async_copy


jax.experimental.pallas.tpu.make_async_copy(src_ref, dst_ref, sem)

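Creates a description of an asynchronous (DMA) copy from src_ref to dst_ref. Creating the description does not issue the transfer: call start() on the returned descriptor to begin the copy and wait() to block until the semaphore reports that it has finished.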

Parameters:
  • src_ref – The source reference.

  • dst_ref – The destination reference.

  • sem – The DMA semaphore used to track completion of the copy.

Returns:

An AsyncCopyDescriptor.

Return type:

AsyncCopyDescriptor