jax.experimental.pallas.semaphore_wait

jax.experimental.pallas.semaphore_wait(sem_or_view, value=1, *, decrement=True)

Blocks execution of the current thread until a semaphore reaches a value.

Parameters:
  • sem_or_view – A Ref (or view) representing a semaphore.

  • value (int | jax_typing.Array) – The target value that the semaphore should reach before unblocking.

  • decrement (bool) – Whether to decrement the value of the semaphore after a successful wait.
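The following toy model may help build intuition for how `value` and `decrement` interact. It is a plain-Python, host-side counting semaphore, not a Pallas kernel and not how Pallas implements semaphores. `ToySemaphore` and its `signal`, `wait`, and `read` methods are hypothetical names that only loosely mirror `semaphore_signal`, `semaphore_wait`, and `semaphore_read`. It also makes two assumptions that this page does not state: "reaches a value" is modeled as "count >= value", and a decrementing wait is modeled as subtracting `value` from the count.

```python
import threading


class ToySemaphore:
    """Hypothetical counting semaphore used only to illustrate semaphore_wait.

    This is a host-side Python model, NOT the Pallas implementation.
    """

    def __init__(self, count=0):
        self._count = count
        self._cond = threading.Condition()

    def signal(self, inc=1):
        # Raise the count by `inc` and wake any blocked waiters.
        with self._cond:
            self._count += inc
            self._cond.notify_all()

    def wait(self, value=1, *, decrement=True):
        # Block until the count reaches `value` (assumed here to mean count >= value).
        with self._cond:
            self._cond.wait_for(lambda: self._count >= value)
            if decrement:
                # Assumption: a successful wait consumes `value` units.
                self._count -= value

    def read(self):
        # Return the current count without modifying it.
        with self._cond:
            return self._count


# One thread signals three times; the main thread waits for all three signals.
sem = ToySemaphore()
produced = []


def producer():
    for i in range(3):
        produced.append(i)  # do some "work" before signalling
        sem.signal(1)


worker = threading.Thread(target=producer)
worker.start()
sem.wait(3)  # returns only after all three signals have arrived
worker.join()
print(produced, sem.read())  # [0, 1, 2] 0

# With decrement=False, the wait leaves the count untouched.
held = ToySemaphore(2)
held.wait(2, decrement=False)
print(held.read())  # 2
```

In this model, `decrement=False` turns the wait into a pure "has the count reached `value`?" check, so several waiters can observe the same count. With the default `decrement=True`, each successful wait consumes what it waited for.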