jax.experimental.pallas module#

Module for Pallas, a JAX extension for custom kernels.

See the Pallas documentation at https://docs.jax.dev/en/latest/pallas/index.html.

Backends#

Classes#

BlockSpec([block_shape, index_map, ...])

Specifies how an array should be sliced for each invocation of a kernel.
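As a minimal sketch of how a BlockSpec slices an array, the example below adds two 1-D arrays block by block. The block size of 128, the helper names, and the use of interpret=True (so that it runs without a GPU or TPU) are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

BLOCK = 128  # illustrative block size, chosen for this sketch


def add_kernel(x_ref, y_ref, o_ref):
    # Each invocation sees only one BLOCK-sized slice of every operand.
    o_ref[...] = x_ref[...] + y_ref[...]


def blocked_add(x, y):
    # Grid step i maps to block i, i.e. elements [i * BLOCK, (i + 1) * BLOCK).
    spec = pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,))
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // BLOCK,),
        in_specs=[spec, spec],
        out_specs=spec,
        interpret=True,  # assumption: interpreter mode, no accelerator needed
    )(x, y)


x = jnp.arange(512, dtype=jnp.float32)
y = jnp.ones(512, dtype=jnp.float32)
result = blocked_add(x, y)
```

The same spec object is reused for both inputs and the output because all three arrays are tiled identically.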

GridSpec([grid, in_specs, out_specs, ...])

Encodes the grid parameters for jax.experimental.pallas.pallas_call().

Slice(start, size[, stride])

A slice with a start index and a size.

Functions#

core_map(mesh, *[, compiler_params, ...])

Runs a function on a mesh, mapping it over the devices in the mesh.

kernel([body, out_shape, scratch_shapes, ...])

Entry point for creating a Pallas kernel.

pallas_call(kernel, out_shape, *[, ...])

Entry point for creating a Pallas kernel.

program_id(axis)

Returns the kernel execution position along the given axis of the grid.

num_programs(axis)

Returns the size of the grid along the given axis.
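To illustrate program_id and num_programs together, the sketch below has each grid step write its reversed position into one element of the output. The grid size of 4 and the kernel name are arbitrary choices, and interpret=True is assumed so that the example runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def reversed_position_kernel(o_ref):
    i = pl.program_id(0)    # position of this invocation along grid axis 0
    n = pl.num_programs(0)  # total number of invocations along grid axis 0
    o_ref[i] = n - 1 - i


reversed_ids = pl.pallas_call(
    reversed_position_kernel,
    out_shape=jax.ShapeDtypeStruct((4,), jnp.int32),
    grid=(4,),
    interpret=True,  # assumption: interpreter mode, no accelerator needed
)()
```

No BlockSpec is given here, so every invocation sees the whole output array and selects its own element with program_id.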

cdiv(a, b)

Computes the ceiling of a divided by b.
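A common use of cdiv is sizing a grid when the array length is not a multiple of the block size. The numbers below are illustrative.

```python
from jax.experimental import pallas as pl

n, block = 1000, 128  # illustrative sizes

# Number of blocks needed to cover n elements; the last block is partial.
num_blocks = pl.cdiv(n, block)

# Length after rounding n up to a whole number of blocks.
padded = num_blocks * block
```

Here num_blocks is 8 and padded is 1024, since 7 blocks would cover only 896 elements.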

dslice(start[, size, stride])

Constructs a Slice from a start index and a size.
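A short sketch of dslice used to index a reference inside a kernel. The start and size values are arbitrary, and interpret=True is assumed so that the example runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def window_kernel(x_ref, o_ref):
    # dslice(2, 4) selects 4 elements starting at index 2, i.e. x[2:6].
    o_ref[...] = x_ref[pl.dslice(2, 4)]


window = pl.pallas_call(
    window_kernel,
    out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
    interpret=True,  # assumption: interpreter mode, no accelerator needed
)(jnp.arange(8, dtype=jnp.float32))
```

Unlike a Python slice, a dslice is described by a start and a size rather than a start and a stop.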

empty(shape, dtype, *[, out_sharding])

Create an empty array of possibly uninitialized values.

empty_like(x)

Create an empty PyTree of possibly uninitialized values.

load(x_ref_or_view, idx, *[, mask, other, ...])

Returns an array loaded from the given index.

store(x_ref_or_view, idx, val, *[, mask, ...])

Stores a value at the given index.
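The sketch below shows one possible use of load with a mask and a fill value, followed by an unmasked store. It assumes the mask and other keyword arguments listed in the signatures above, an input of length 8, and interpret=True; the cutoff of 5 valid elements is arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def masked_double_kernel(x_ref, o_ref):
    idx = (pl.dslice(0, 8),)   # index tuple covering all 8 elements
    mask = jnp.arange(8) < 5   # treat only the first 5 elements as valid
    # Masked-out positions are replaced by `other` instead of being read.
    x = pl.load(x_ref, idx, mask=mask, other=0.0)
    pl.store(o_ref, idx, x * 2)


masked = pl.pallas_call(
    masked_double_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    interpret=True,  # assumption: interpreter mode, no accelerator needed
)(jnp.arange(8, dtype=jnp.float32))
```

The store is left unmasked here so that every output element is written and the result is fully defined.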

swap(x_ref_or_view, idx, val, *[, mask, ...])

Swaps the value at the given index and returns the old value.

broadcast_to(a, shape)

Broadcasts an array to a new shape.

debug_check(condition, message)

Check the condition if enable_debug_checks() is set, otherwise do nothing.

debug_print(fmt, *args)

Prints values from inside a Pallas kernel.

dot(a, b[, trans_a, trans_b, allow_tf32, ...])

Computes the dot product of two arrays.
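As a hedged sketch, dot can be applied to two 2-D blocks read from references. The shapes and names below are illustrative, and the whole matrices are processed in a single invocation with interpret=True.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def matmul_kernel(a_ref, b_ref, o_ref):
    # Both operands must be 2-D; the result is cast to the output dtype.
    o_ref[...] = pl.dot(a_ref[...], b_ref[...]).astype(o_ref.dtype)


a = jnp.arange(8 * 16, dtype=jnp.float32).reshape(8, 16) / 128.0
b = jnp.ones((16, 8), dtype=jnp.float32)

product = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 8), jnp.float32),
    interpret=True,  # assumption: interpreter mode, no accelerator needed
)(a, b)
```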

get_global(what)

Returns a global reference that persists across all kernel invocations.

loop(lower, upper, *[, step, unroll])

Returns a decorator that calls the decorated function in a loop.

max_contiguous(x, values)

A compiler hint asserting that the first N values of x are contiguous, where N is given by the values argument.

multiple_of(x, values)

A compiler hint that asserts a value is a static multiple of another.

run_scoped(f, *types[, collective_axes])

Calls the function with allocated references and returns the result.

when(condition, /)

Calls the decorated function when the condition is met.
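A typical pattern is to use when to initialize an accumulator on the first grid step only. The sketch below sums a 1-D array in 4 blocks of 8 elements; the sizes, the names, and the use of interpret=True are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def sum_kernel(x_ref, o_ref):
    # Zero the accumulator only on the first grid step.
    @pl.when(pl.program_id(0) == 0)
    def _init():
        o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)

    # Every grid step adds the sum of its own block to the same output block.
    o_ref[...] += jnp.sum(x_ref[...], keepdims=True)


total = pl.pallas_call(
    sum_kernel,
    out_shape=jax.ShapeDtypeStruct((1,), jnp.float32),
    grid=(4,),
    in_specs=[pl.BlockSpec(block_shape=(8,), index_map=lambda i: (i,))],
    # All grid steps map to output block 0, so it acts as an accumulator.
    out_specs=pl.BlockSpec(block_shape=(1,), index_map=lambda i: (0,)),
    interpret=True,  # assumption: interpreter mode, no accelerator needed
)(jnp.arange(32, dtype=jnp.float32))
```

This relies on the grid steps being executed one after another, which holds in interpreter mode; on real hardware the execution order depends on the backend.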

Synchronization#

semaphore_read(sem_or_view)

Reads the value of a semaphore.

semaphore_signal(sem_or_view[, inc, ...])

Increments the value of a semaphore.

semaphore_wait(sem_or_view[, value, decrement])

Blocks execution of the current thread until a semaphore reaches a value.