jax.experimental.pallas.get_global

jax.experimental.pallas.get_global(what)

Returns a global reference that persists across all kernel invocations.

Each call to get_global returns its own distinct reference, and that reference is stable across invocations of the kernel body.
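The two properties above (a distinct reference per call, and the same reference on every invocation) can be illustrated with a toy model in plain Python. This is only a sketch of the described semantics, not how Pallas implements them: `ToyGlobals`, `start_invocation`, `kernel_body`, and `run_kernel` are hypothetical names, and matching calls by their order within an invocation is an assumption made for illustration.

```python
class ToyGlobals:
    """Toy model of global references; hypothetical, not part of Pallas."""

    def __init__(self):
        self._refs = []  # one entry per get_global call, in call order
        self._next = 0   # index of the next call within the current invocation

    def start_invocation(self):
        # Reset the call counter so the same calls map to the same references.
        self._next = 0

    def get_global(self, what):
        if self._next == len(self._refs):
            # First time this call is reached: allocate a fresh reference.
            self._refs.append({"type": what, "count": 0})
        ref = self._refs[self._next]
        self._next += 1
        return ref


def kernel_body(g):
    # Two calls yield two distinct references.
    a = g.get_global("semaphore")
    b = g.get_global("semaphore")
    a["count"] += 1  # state written here is visible to later invocations
    return a, b


def run_kernel(g):
    g.start_invocation()
    return kernel_body(g)


g = ToyGlobals()
a1, b1 = run_kernel(g)
a2, b2 = run_kernel(g)

distinct_within_invocation = a1 is not b1
stable_across_invocations = (a1 is a2) and (b1 is b2)
```

In this model, state written through a reference in one invocation remains visible in the next, which is the sense in which the reference "persists".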

Parameters:

what (pallas_core.ScratchShape) – The reference type to allocate. Each backend has its own set of reference types (e.g., jax.experimental.pallas.mosaic_gpu.SemaphoreType for GPU).

Return type:

jax_typing.Array

Example:

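from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

# Inside a Pallas kernel body (assumed context for this snippet):
sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR)  # allocate a global semaphore
pl.semaphore_signal(sem_ref)
pl.semaphore_wait(sem_ref)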