jax.experimental.pallas.semaphore_signal

jax.experimental.pallas.semaphore_signal(sem_or_view, inc=1, *, device_id=None, device_id_type=DeviceIdType.MESH, core_index=None)

Increments the value of a semaphore.

This operation can also be performed remotely if device_id is specified, in which case sem_or_view refers to a Ref located on another device. Note that sem_or_view is assumed to be already allocated (e.g. through the proper use of barriers); otherwise this operation could result in undefined behavior.

Parameters:
  • sem_or_view – A Ref (or view) representing a semaphore.

  • inc (int | jax_typing.Array) – The value to increment by.

  • device_id (optional) – Specifies which device to signal. If not specified, sem_or_view is assumed to be local.

  • device_id_type (optional) – The format in which device_id should be specified.

  • core_index (optional) – If on a multi-core device, specifies which core to signal.
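
The local and remote cases described above can be illustrated with a small pure-Python model. This is only a sketch of the counting semantics, not Pallas code: ToySemaphore, toy_semaphore_signal, the local_id argument, and the dictionary mapping device ids to semaphores are all hypothetical names introduced for illustration. The model ignores device_id_type and core_index, and it raises an error where real hardware would exhibit undefined behavior.

```python
from dataclasses import dataclass


@dataclass
class ToySemaphore:
    """Hypothetical counting semaphore; a stand-in for a semaphore Ref."""
    value: int = 0


def toy_semaphore_signal(sems, inc=1, *, device_id=None, local_id=0):
    """Add `inc` to the semaphore on `device_id`, or on `local_id` if None."""
    target = local_id if device_id is None else device_id
    # Models the allocation caveat: signalling a semaphore that has not been
    # allocated is an explicit error here (undefined behavior on real hardware).
    if target not in sems:
        raise KeyError(f"no semaphore allocated on device {target}")
    sems[target].value += inc
    return sems[target].value


# Two hypothetical devices, each holding one semaphore that starts at zero.
sems = {0: ToySemaphore(), 1: ToySemaphore()}

# Local signal: device_id is omitted, so the semaphore on device 0 is used.
local_value = toy_semaphore_signal(sems)

# Remote signal: increments the semaphore on device 1 by 3.
remote_value = toy_semaphore_signal(sems, 3, device_id=1)
```

In an actual kernel the semaphore would be a Ref supplied by Pallas rather than a Python object, and a remote signal would need the target semaphore to be allocated first, for example by synchronizing the devices with a barrier.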