jax.experimental.pallas.mosaic_gpu.semaphore_signal_parallel

jax.experimental.pallas.mosaic_gpu.semaphore_signal_parallel(*signals)

Signals multiple semaphores without any guaranteed ordering of signal arrivals.

This primitive is largely equivalent to:

for sem in semaphores:
  pl.semaphore_signal(sem, inc, device_id=device_id)

Unlike the loop above, however, this primitive does not guarantee any ordering of signal arrivals. In particular, the target device might observe a signal on semaphores[1] before it observes a signal on semaphores[0]. The operation still guarantees that any side effects performed before the signals are issued will be complete and visible before any of the signals arrive.

The relaxed requirements make the whole operation significantly cheaper on GPUs, as a single expensive memory fence can be used for all signals (instead of an expensive fence for each signal).
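The cost argument can be illustrated with a toy model in plain Python. This is only a sketch of the semantics described above, not how Mosaic GPU implements the primitive: the event log, the helper names (signal_sequential, signal_parallel, count) and the string semaphore names are all hypothetical, and the random shuffle merely stands in for "no guaranteed arrival order".

```python
import random

def signal_sequential(log, semaphores):
    # Models the loop: every signal pays for its own fence,
    # and signals arrive in program order.
    for sem in semaphores:
        log.append(("fence", None))
        log.append(("signal", sem))

def signal_parallel(log, semaphores, rng):
    # Models the relaxed primitive: one fence covers all signals,
    # which may then arrive in any order (shuffled here for illustration).
    log.append(("fence", None))
    for sem in rng.sample(semaphores, k=len(semaphores)):
        log.append(("signal", sem))

def count(log, kind):
    # Number of events of the given kind in the log.
    return sum(1 for k, _ in log if k == kind)

semaphores = ["sem0", "sem1", "sem2", "sem3"]

# Both logs start with a side effect that must be visible before any signal.
seq_log = [("write", "payload")]
signal_sequential(seq_log, semaphores)

par_log = [("write", "payload")]
signal_parallel(par_log, semaphores, random.Random(0))

print("fences (sequential):", count(seq_log, "fence"))
print("fences (parallel):  ", count(par_log, "fence"))
```

In this model the sequential version issues one fence per semaphore, while the parallel version issues a single fence. In both cases the write precedes every signal, so a receiver that waits on all semaphores before reading is unaffected by the arrival order.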

Parameters:

signals (SemaphoreSignal)