jax.experimental.pallas.with_scoped

jax.experimental.pallas.with_scoped(*types, collective_axes=(), **kw_types)

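Returns a function decorator that runs the decorated function with the requested allocations passed in as arguments. Allocations given positionally in types arrive as positional arguments, and those given by name in kw_types arrive as keyword arguments of the same name.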

Example:

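import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

@pl.with_scoped(pltpu.VMEM((8, 128), jnp.float32),
                sem_ref=pltpu.SemaphoreType.DMA)
def f(vmem_ref, sem_ref):
  ...

f()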

Arguments passed to f at the call site are forwarded to the decorated function as its leading arguments, ahead of the allocated references.

Example:

@pl.with_scoped(pltpu.VMEM((8, 128), jnp.float32),
                sem_ref=pltpu.SemaphoreType.DMA)
def f(outer_ref, vmem_ref, sem_ref):
  ...

outer_ref = ...
f(outer_ref)
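
The argument order described above can be illustrated with a plain-Python model. This is a minimal sketch and not the JAX implementation: with_scoped_model is a hypothetical stand-in that performs no real allocation and simply tags each requested type, so that the position of every argument is visible without a TPU. It also ignores collective_axes.

```python
import functools


def with_scoped_model(*types, **kw_types):
    """Hypothetical model of the decorator's argument forwarding only."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Stand-ins for the allocated references (no real allocation).
            scoped = [("alloc", t) for t in types]
            scoped_kw = {name: ("alloc", t) for name, t in kw_types.items()}
            # Caller arguments come first, then the scoped allocations.
            return fn(*args, *scoped, **kwargs, **scoped_kw)

        return wrapper

    return decorator


@with_scoped_model("VMEM((8, 128), float32)", sem_ref="SemaphoreType.DMA")
def f(outer_ref, vmem_ref, sem_ref):
    return outer_ref, vmem_ref, sem_ref


result = f("outer")
```

Here result holds the caller's value first, followed by the two tagged stand-ins, mirroring the signature f(outer_ref, vmem_ref, sem_ref) in the example above.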
Parameters: