jax.experimental.pallas.dslice

jax.experimental.pallas.dslice(start, size=None, stride=None)

Constructs a Slice from a start index and a size.

The semantics of dslice mirror those of the builtin slice type:

  • dslice(None) is :

  • dslice(j) is :j

  • dslice(i, j) is i:i+j

  • dslice(i, j, stride) is i:i+j:stride
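
For static integer arguments, the correspondence above can be modelled in plain Python. The helper below, dslice_as_builtin, is hypothetical (it is not part of JAX), covers only the start/size forms, and leaves stride out. Its purpose is to show that the second argument of dslice is a size, not a stop index: dslice(4, 2) covers 4:6.

```python
def dslice_as_builtin(start, size=None):
    """Hypothetical helper (not a JAX API): return the builtin slice that
    dslice(start, size) stands for, for static Python ints only.
    The stride argument is intentionally left out of this sketch."""
    if start is None:
        return slice(None)                 # dslice(None)  ~  :
    if size is None:
        return slice(None, start)          # dslice(j)     ~  :j
    return slice(start, start + size)      # dslice(i, j)  ~  i:i+j


xs = list(range(10))
xs[dslice_as_builtin(None)]   # the whole list
xs[dslice_as_builtin(3)]      # [0, 1, 2]
xs[dslice_as_builtin(4, 2)]   # [4, 5]
```

Because the length of the result is fixed by size rather than derived from start, the output shape stays known at trace time even when start is dynamic. This is why the jitted example below marks size as static while leaving i dynamic.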

Examples

>>> x = jax.numpy.arange(10)
>>> i = 4
>>> x[i: i + 2]  # standard indexing requires i to be static
Array([4, 5], dtype=int32)
>>> x[jax.ds(i, 2)]  # equivalent which allows i to be dynamic
Array([4, 5], dtype=int32)

Here is an explicit example of slicing with a dynamic start index:

>>> @jax.jit(static_argnames='size')
... def f(x, i, size):  # i is traced (dynamic); size must be static
...   return x[jax.ds(i, size)]
...
>>> f(x, i, 2)
Array([4, 5], dtype=int32)
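
Since dslice is exported from jax.experimental.pallas, it can also index Refs inside a Pallas kernel. The following is a minimal sketch under stated assumptions: the names scale_windows_kernel and scale_windows are hypothetical, the input length is assumed to be a multiple of 2, and interpret=True is used so the kernel runs on CPU through the Pallas interpreter. The loop index of jax.lax.fori_loop is a traced value, so every window start passed to dslice is dynamic.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def scale_windows_kernel(x_ref, o_ref):
    # Hypothetical kernel: walk the input in windows of 2 elements.
    # Assumes x_ref.shape[0] is a multiple of 2.
    def body(i, carry):
        # i is traced, so the start 2*i is dynamic; the size 2 is static.
        window = pl.dslice(i * 2, 2)
        o_ref[window] = x_ref[window] * 10
        return carry

    jax.lax.fori_loop(0, x_ref.shape[0] // 2, body, None)


def scale_windows(x):
    # No grid or BlockSpecs: the kernel sees the whole input and output.
    return pl.pallas_call(
        scale_windows_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # run on CPU via the Pallas interpreter
    )(x)


result = scale_windows(jnp.arange(8, dtype=jnp.float32))
```

On an accelerator one would normally drop interpret=True and pick block shapes suited to the hardware; those choices are outside the scope of this sketch.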
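
Parameters:

  • start (int | Array | None) – start index of the slice. When size is given, start may be a dynamic (traced) value. dslice(None) is equivalent to :.

  • size (int | Array | None) – size of the slice. If omitted, start is used as the stop of :start.

  • stride (int | None) – step between selected elements, as in the builtin slice. If omitted, the slice is unstrided.
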
Return type:

slice | Slice