jax.experimental.pallas.with_scoped
jax.experimental.pallas.with_scoped(*types, collective_axes=(), **kw_types)
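Returns a function decorator that runs the decorated function with the provided scoped allocations. It is the decorator form of pl.run_scoped. Each entry in types and kw_types describes a buffer or semaphore to allocate (for example pltpu.VMEM(...) or pltpu.SemaphoreType.DMA). When the decorated function is called, it receives one reference per allocation: references for types are passed positionally, in order, and references for kw_types are passed as keyword arguments under the same names. The allocations exist only for the duration of that call.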
Example:
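    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    @pl.with_scoped(pltpu.VMEM((8, 128), jnp.float32), sem_ref=pltpu.SemaphoreType.DMA)
    def f(vmem_ref, sem_ref):
        ...
    f()  # call inside a Pallas kernel body; vmem_ref and sem_ref are allocated for this call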
Any arguments passed to f are forwarded to the decorated function as its leading positional arguments, ahead of the allocated references.
Example:
    @pl.with_scoped(pltpu.VMEM((8, 128), jnp.float32), sem_ref=pltpu.SemaphoreType.DMA)
    def f(outer_ref, vmem_ref, sem_ref):
        ...
    outer_ref = ...
    f(outer_ref)  # outer_ref arrives first, followed by vmem_ref and sem_ref
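The calling convention above can be modeled in plain Python, without JAX. This is not the Pallas implementation: ScopedRef, run_scoped_model and with_scoped_model are hypothetical names, and the strings stand in for allocation types such as pltpu.VMEM(...). The model only shows the argument order the examples describe: the caller's arguments first, then references for positional types, then keyword references.

```python
from typing import Any, Callable


class ScopedRef:
    """Hypothetical stand-in for a reference to a scoped allocation."""

    def __init__(self, spec: Any) -> None:
        self.spec = spec


def run_scoped_model(body: Callable[..., Any], *types: Any, **kw_types: Any) -> Any:
    # Create one stand-in ref per requested type and call body with them:
    # positional types become positional refs, keyword types keyword refs.
    refs = [ScopedRef(t) for t in types]
    kw_refs = {name: ScopedRef(t) for name, t in kw_types.items()}
    return body(*refs, **kw_refs)


def with_scoped_model(*types: Any, **kw_types: Any):
    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args: Any) -> Any:
            # The caller's args are prepended ahead of the scoped refs.
            return run_scoped_model(
                lambda *refs, **kw_refs: f(*args, *refs, **kw_refs),
                *types,
                **kw_types,
            )

        return wrapper

    return decorator


@with_scoped_model("vmem_spec", sem_ref="dma_sem_spec")
def f(outer_ref, vmem_ref, sem_ref):
    return (outer_ref, vmem_ref.spec, sem_ref.spec)


result = f("outer")  # ('outer', 'vmem_spec', 'dma_sem_spec')
```

The real decorator additionally manages the lifetime of the allocations, which this model ignores.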