jax.experimental.pallas.triton module


Triton-specific Pallas APIs.

Classes

CompilerParams([num_warps, num_stages])

Compiler parameters for Triton.

Functions

atomic_and(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] &= val.

atomic_add(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] += val.

atomic_cas(ref, cmp, val)

Performs an atomic compare-and-swap of the value in the ref with the given value.
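
To make the compare-and-swap contract concrete, here is a minimal plain-Python model of it. This is a sketch of the semantics only, not the Pallas implementation: `atomic_cas_model` is a hypothetical name, the ref is modeled as a one-element list, and returning the previous value is an assumption based on the usual compare-and-swap convention. On a GPU the read, compare, and write happen as one indivisible step; the model has no concurrency and cannot show that part.

```python
def atomic_cas_model(cell, cmp, val):
    """Plain-Python model of compare-and-swap on a one-element list.

    Assumption: the operation returns the value held before the swap,
    whether or not the swap took place.
    """
    old = cell[0]
    if old == cmp:      # swap only when the current value matches `cmp`
        cell[0] = val
    return old


# A spin-lock style use: 0 means "free", 1 means "held".
lock = [0]
first = atomic_cas_model(lock, 0, 1)   # lock was free, so it is taken
second = atomic_cas_model(lock, 0, 1)  # lock is already held, nothing changes
```

A caller can tell whether it acquired the lock by checking that the returned value equals `cmp`.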

atomic_max(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] = max(x_ref_or_view[idx], val).

atomic_min(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] = min(x_ref_or_view[idx], val).

atomic_or(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] |= val.

atomic_xchg(x_ref_or_view, idx, val, *[, mask])

Atomically exchanges the given value with the value at the given index.

atomic_xor(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] ^= val.
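
The read-modify-write entries above (and, add, max, min, or, xchg, xor) all follow one pattern: read the element at idx, combine it with val, and write the result back. The sketch below models that pattern in plain Python. It is not the Pallas implementation: `atomic_rmw_model` and `RMW_OPS` are hypothetical names, the scalar `mask` flag stands in for the array-valued mask argument, and returning the previous value is an assumption.

```python
import operator

# One combining function per atomic in the listing (hypothetical table).
RMW_OPS = {
    "and": operator.and_,
    "add": operator.add,
    "max": max,
    "min": min,
    "or": operator.or_,
    "xchg": lambda old, val: val,  # exchange ignores the old value
    "xor": operator.xor,
}


def atomic_rmw_model(buf, idx, val, op, mask=True):
    """Model of buf[idx] = op(buf[idx], val); returns the old value.

    Assumption: a False mask leaves the element untouched.
    """
    old = buf[idx]
    if mask:
        buf[idx] = RMW_OPS[op](old, val)
    return old


counts = [0, 0]
atomic_rmw_model(counts, 0, 1, "add")
atomic_rmw_model(counts, 0, 1, "add")
atomic_rmw_model(counts, 1, 5, "max")
```

In a real kernel the point of the atomic variants is that many program instances may target the same index at once, for example when building a histogram; this sequential model only documents the per-call arithmetic.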

approx_tanh(x)

Elementwise approximate hyperbolic tangent: \(\mathrm{tanh}(x)\).

debug_barrier()

Synchronizes all kernel executions in the grid.

elementwise_inline_asm(asm, *, args, ...)

Inline assembly applying an elementwise operation.

load(ref, *[, mask, other, cache_modifier, ...])

Loads an array from the given ref.

max_contiguous(x, values)

A compiler hint asserting that the first N elements of x are contiguous, where N is given by values.

store(ref, val, *[, mask, eviction_policy])

Stores a value to the given ref.
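
The mask arguments of load and store are typically used for the last block of an array whose length is not a multiple of the block size. The following plain-Python sketch models that use. It is not the Pallas implementation: `masked_load_model` and `masked_store_model` are hypothetical helpers, lists stand in for refs, and defaulting `other` to 0 is an assumption made for the model (a real masked load without `other` may leave the masked-off lanes unspecified).

```python
def masked_load_model(data, offsets, mask, other=0):
    """Read data[i] where the mask is True, substitute `other` elsewhere."""
    return [data[i] if m else other for i, m in zip(offsets, mask)]


def masked_store_model(data, offsets, values, mask):
    """Write values[k] to data[offsets[k]] only where the mask is True."""
    for i, v, m in zip(offsets, values, mask):
        if m:
            data[i] = v


src = [1, 2, 3, 4, 5]
dst = [0] * len(src)
block = 4  # block size that does not divide len(src)

for start in range(0, len(src), block):
    offsets = list(range(start, start + block))
    mask = [i < len(src) for i in offsets]   # False for out-of-range lanes
    tile = masked_load_model(src, offsets, mask, other=0)
    masked_store_model(dst, offsets, [2 * v for v in tile], mask)
```

The second iteration covers offsets 4 through 7, of which only offset 4 is in range; the mask keeps the other three lanes from reading or writing past the end of the array.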