jax.experimental.pallas.dslice
- jax.experimental.pallas.dslice(start, size=None, stride=None)
Constructs a Slice from a start index and a size.

The semantics of dslice mirror those of the builtin slice type:

- dslice(None) is :
- dslice(j) is :j
- dslice(i, j) is i:i+j
- dslice(i, j, stride) is i:i+j:stride
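For static integer arguments, the unstrided rows of the semantics listed above can be written out as ordinary Python slice objects. The helper dslice_as_builtin below is hypothetical and only models those rows; it is not part of Pallas, and a builtin slice cannot represent the traced start indices that dslice itself accepts.

```python
def dslice_as_builtin(start, size=None):
    """Hypothetical helper: the builtin slice that dslice(start, size) mirrors.

    Only static Python ints are modeled here; the strided form is left out.
    """
    if start is None:
        return slice(None)                 # dslice(None) -> :
    if size is None:
        return slice(start)                # dslice(j)    -> :j
    return slice(start, start + size)      # dslice(i, j) -> i:i+j


data = list(range(10))
print(data[dslice_as_builtin(None)])  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(data[dslice_as_builtin(3)])     # [0, 1, 2]
print(data[dslice_as_builtin(4, 2)])  # [4, 5]
```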
Examples
>>> x = jax.numpy.arange(10)
>>> i = 4
>>> x[i:i + 2]  # standard indexing requires i to be static
Array([4, 5], dtype=int32)
>>> x[jax.ds(i, 2)]  # equivalent, but allows i to be dynamic
Array([4, 5], dtype=int32)
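For comparison, core JAX expresses the same dynamic-start, static-size slice with jax.lax.dynamic_slice. This is a sketch for orientation only: the function name window is hypothetical, and it does not claim that dslice is implemented this way. Note that lax.dynamic_slice clamps an out-of-range start index so the slice stays in bounds; no claim is made here about whether dslice behaves the same at the boundary.

```python
import jax
import jax.numpy as jnp


@jax.jit
def window(x, i):
    # i is traced under jit; the slice size (2,) must be a static tuple.
    return jax.lax.dynamic_slice(x, (i,), (2,))


x = jnp.arange(10)
print(window(x, 4))  # [4 5]
print(window(x, 9))  # start is clamped to 8, so this prints [8 9]
```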
Here is an explicit example of slicing with a dynamic start index:
>>> @jax.jit(static_argnames='size')
... def f(x, i, size):  # i may be dynamic; size must be static
...   return x[jax.ds(i, size)]
...
>>> f(x, i, 2)
Array([4, 5], dtype=int32)
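Inside a Pallas kernel, dslice is typically used to index Refs. The following is a minimal sketch, assuming a CPU-only setup run through pallas_call's interpret=True mode; the kernel window_kernel, the wrapper take_window, and the constant WINDOW are hypothetical names. Passing the start index as a one-element int32 input makes it a run-time value inside the kernel, while the window length stays a static Python int.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

WINDOW = 4  # static window length (hypothetical choice for this sketch)


def window_kernel(start_ref, x_ref, o_ref):
    # The start index is read from a Ref at run time, so it is dynamic.
    start = start_ref[0]
    # pl.dslice(start, WINDOW) selects x_ref[start:start + WINDOW].
    o_ref[...] = x_ref[pl.dslice(start, WINDOW)]


@jax.jit
def take_window(x, start):
    starts = jnp.array([start], dtype=jnp.int32)
    return pl.pallas_call(
        window_kernel,
        out_shape=jax.ShapeDtypeStruct((WINDOW,), x.dtype),
        interpret=True,  # run on CPU via the Pallas interpreter
    )(starts, x)


print(take_window(jnp.arange(10, dtype=jnp.float32), 3))  # [3. 4. 5. 6.]
```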