jax.experimental.pallas module
Module for Pallas, a JAX extension for custom kernels.
See the Pallas documentation at https://docs.jax.dev/en/latest/pallas/index.html.
Backends
Classes
Specifies how an array should be sliced for each invocation of a kernel.
Encodes the grid parameters for
A slice with a start index and a size.
Functions
Runs a function on a mesh, mapping it over the devices in the mesh.
Entry point for creating a Pallas kernel.
Entry point for creating a Pallas kernel.
Returns the kernel execution position along the given axis of the grid.
Returns the size of the grid along the given axis.
Computes the ceiling division of a divided by b.
Constructs a
Creates an empty array of possibly uninitialized values.
Creates an empty PyTree of possibly uninitialized values.
Returns an array loaded from the given index.
Stores a value at the given index.
Swaps the value at the given index and returns the old value.
Broadcasts an array to a new shape.
Checks the condition if
Prints values from inside a Pallas kernel.
Computes the dot product of two arrays.
Returns a global reference that persists across all kernel invocations.
Returns a decorator that calls the decorated function in a loop.
A compiler hint that asserts the
A compiler hint that asserts a value is a static multiple of another.
Calls the function with allocated references and returns the result.
Calls the decorated function when the condition is met.
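Several of these functions are typically used together in a grid reduction. The sketch below assumes the Pallas functions `pl.program_id`, `pl.num_programs`, `pl.cdiv`, `pl.when` and `pl.pallas_call` (run with `interpret=True`). Each grid step adds its input block into one shared output block; `pl.when` zeroes that block on the first step and turns the running sum into a mean on the last. The names `block_mean_kernel`, `block_mean` and `BLOCK_ROWS` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

BLOCK_ROWS = 8  # rows per grid step (hypothetical choice for this sketch)


def block_mean_kernel(x_ref, o_ref):
    step = pl.program_id(0)         # position of this invocation along grid axis 0
    num_steps = pl.num_programs(0)  # grid size along axis 0

    # Zero the accumulator only on the first grid step.
    @pl.when(step == 0)
    def _init():
        o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)

    # Every step adds its input block into the same output block.
    o_ref[...] = o_ref[...] + x_ref[...]

    # On the last step, divide the running sum by the number of steps.
    @pl.when(step == num_steps - 1)
    def _finalize():
        o_ref[...] = o_ref[...] / jnp.asarray(num_steps, o_ref.dtype)


def block_mean(x):
    rows, cols = x.shape
    # Partial blocks would be padded with unspecified values and corrupt the sum.
    assert rows % BLOCK_ROWS == 0, "this sketch assumes no partial blocks"
    # With Python ints, cdiv returns a Python int (here equal to rows // BLOCK_ROWS).
    n_steps = pl.cdiv(rows, BLOCK_ROWS)
    return pl.pallas_call(
        block_mean_kernel,
        out_shape=jax.ShapeDtypeStruct((BLOCK_ROWS, cols), x.dtype),
        grid=(n_steps,),
        in_specs=[pl.BlockSpec(block_shape=(BLOCK_ROWS, cols), index_map=lambda i: (i, 0))],
        # Every grid step maps to output block (0, 0), so it acts as an accumulator.
        out_specs=pl.BlockSpec(block_shape=(BLOCK_ROWS, cols), index_map=lambda i: (0, 0)),
        interpret=True,  # emulate on CPU
    )(x)


# Usage: the result should equal the mean of the four (8, 128) row blocks.
x = jnp.arange(32 * 128, dtype=jnp.float32).reshape(32, 128)
out = block_mean(x)
np.testing.assert_allclose(out, np.asarray(x).reshape(4, 8, 128).mean(axis=0))
```

The accumulator pattern relies on grid steps running one after another, which holds in interpret mode and, as we understand the Pallas documentation, on TPU. On backends that may run grid steps in parallel (such as GPU), revisiting one output block from every step is not safe, so treat this sketch as TPU/interpret-mode only.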
Synchronization
Reads the value of a semaphore.
Increments the value of a semaphore.
Blocks execution of the current thread until a semaphore reaches a value.