jax.experimental.pallas.swap
- jax.experimental.pallas.swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, _function_name='swap')
Swaps the value at the given index and returns the old value.
See load() for the meaning of the arguments.
- Returns:
The value stored in the ref prior to the swap.
- Return type:
jax_typing.Array
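
The following is a minimal sketch of how swap might be used inside a kernel, not an official example. It assumes a JAX version whose signature matches the one above, and it runs with interpret=True so that no accelerator is needed. The kernel name swap_kernel, the array shapes, and the choice to copy the input into an output ref before swapping are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def swap_kernel(x_ref, val_ref, out_ref, old_ref):
    # Hypothetical kernel: seed the output ref with the input contents.
    out_ref[...] = x_ref[...]
    # Write val into out_ref over the full slice; swap returns what was
    # stored there before the write.
    old = pl.swap(out_ref, (slice(None),), val_ref[...])
    old_ref[...] = old


x = jnp.arange(8, dtype=jnp.float32)
val = jnp.full((8,), -1.0, dtype=jnp.float32)

out, old = pl.pallas_call(
    swap_kernel,
    out_shape=(
        jax.ShapeDtypeStruct(x.shape, x.dtype),  # contents after the swap
        jax.ShapeDtypeStruct(x.shape, x.dtype),  # contents before the swap
    ),
    interpret=True,  # assumption: interpreter mode, for portability
)(x, val)
```

Here old should hold the values that were in the ref before the swap (a copy of x), while out should hold val. Passing mask would restrict which elements are exchanged; see load() for its semantics.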