jax.experimental.pallas.run_scoped
- jax.experimental.pallas.run_scoped(f, *types, collective_axes=(), **kw_types)
Calls the function with allocated references and returns the result.
The positional and keyword arguments describe which reference type to allocate for each argument of f. Each backend has its own set of reference types in addition to jax.experimental.pallas.MemoryRef.
When collective_axes is specified, the same allocation is returned for all programs that differ only in their program ids along the collective axes. It is an error not to call the same run_scoped in all programs along those axes.