jax.experimental.pallas.loop
- jax.experimental.pallas.loop(lower: jax_typing.ArrayLike, upper: jax_typing.ArrayLike, *, init_carry: None = None, step: jax_typing.ArrayLike = 1, unroll: int | bool | None = None) → Callable[[Callable[[jax_typing.Array], None]], None]
- jax.experimental.pallas.loop(lower: jax_typing.ArrayLike, upper: jax_typing.ArrayLike, *, init_carry: _T = None, step: jax_typing.ArrayLike = 1, unroll: int | bool | None = None) → Callable[[Callable[[jax_typing.Array, _T], _T]], _T]
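Returns a decorator that calls the decorated function in a loop, once per index from lower to upper in increments of step. Without init_carry (first overload), the decorated function takes the current index and returns None, and applying the decorator returns None. With init_carry (second overload), the decorated function takes the current index and the current carry and returns the next carry, and applying the decorator returns the final carry.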