jax.experimental.pallas.get_global
- jax.experimental.pallas.get_global(what)
Returns a global reference that persists across all kernel invocations.
Each call to get_global returns a distinct reference, and that reference stays the same across invocations of the kernel body.
- Parameters:
what (pallas_core.ScratchShape) – The reference type to allocate. Each backend has its own set of reference types (e.g., jax.experimental.pallas.mosaic_gpu.SemaphoreType for GPU).
- Return type:
jax_typing.Array
Example:
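from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
# Inside a Pallas kernel body (Mosaic GPU backend):
sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR)
pl.semaphore_signal(sem_ref)
pl.semaphore_wait(sem_ref)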
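The contract described above is that each call returns a distinct reference, and that reference is reused on every invocation. It can be illustrated without a GPU. The sketch below is a pure-Python toy model and does not show how Pallas implements get_global. ToyKernel, its get_global method, and the zero-filled list buffers are all hypothetical stand-ins. Each get_global call is identified by its position in the body. This loosely mimics a kernel body that is traced once and then run many times.

```python
# Toy model (hypothetical, NOT the Pallas implementation) of the contract
# described for get_global: each call site in the kernel body gets its own
# buffer, allocated once and then reused by every later invocation.
from typing import Callable, List


class ToyKernel:
    """Runs `body` repeatedly; get_global buffers persist between runs."""

    def __init__(self, body: Callable[["ToyKernel"], None]) -> None:
        self._body = body
        self.global_buffers: List[List[int]] = []  # one per get_global call
        self._next = 0          # index of the next get_global call in a run
        self._first_run = True  # True until the body has executed once

    def get_global(self, size: int) -> List[int]:
        if self._first_run:
            # First run: allocate a fresh, distinct buffer (zero-filled here;
            # the initial contents are an assumption of this toy model).
            self.global_buffers.append([0] * size)
        buf = self.global_buffers[self._next]
        self._next += 1
        return buf

    def __call__(self) -> None:
        self._next = 0
        self._body(self)
        self._first_run = False


def body(k: ToyKernel) -> None:
    counter = k.get_global(1)  # same buffer on every invocation
    other = k.get_global(1)    # a second call site: a different buffer
    local = [0]                # ordinary local state, recreated each run
    counter[0] += 1
    local[0] += 1
    other[0] = local[0]        # local never accumulates past 1


kernel = ToyKernel(body)
for _ in range(3):
    kernel()

print(kernel.global_buffers)  # [[3], [1]]
```

In this model, `counter` keeps accumulating across the three runs because it is the same buffer each time. `local` is rebuilt on every run. The two get_global calls never alias each other.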