jax.experimental.pallas.tpu module#

Mosaic-specific Pallas APIs.

Classes#

ChipVersion(value[, names, module, ...])

CompilerParams([dimension_semantics, ...])

Mosaic TPU compiler parameters.

GridDimensionSemantics(value[, names, ...])

MemorySpace(value[, names, module, ...])

PrefetchScalarGridSpec(num_scalar_prefetch)

SemaphoreType(value[, names, module, ...])

TpuInfo(*, chip_version, generation, ...[, ...])

TPU hardware information.

Communication#

async_copy(src_ref, dst_ref, sem, *[, ...])

Issues a DMA copying from src_ref to dst_ref.

async_remote_copy(src_ref, dst_ref, ...[, ...])

Issues a remote DMA copying from src_ref to dst_ref.

make_async_copy(src_ref, dst_ref, sem)

Creates a description of an asynchronous copy operation.

make_async_remote_copy(src_ref, dst_ref, ...)

Creates a description of a remote copy operation.

sync_copy(src_ref, dst_ref)

Copies a PyTree of Refs to another PyTree of Refs.

Pipelining#

BufferedRef(_spec, _buffer_type, window_ref, ...)

A helper class to automate VMEM double buffering in Pallas pipelines.

BufferedRefBase()

Abstract interface for BufferedRefs.

emit_pipeline(body, *, grid[, in_specs, ...])

Creates a function to emit a manual Pallas pipeline.

emit_pipeline_with_allocations(body, *, grid)

Creates a Pallas pipeline and top-level allocation preparation functions.

get_pipeline_schedule(schedule)

Retrieves a named pipeline schedule, or passes through a fully specified one.

make_pipeline_allocations(*refs[, in_specs, ...])

Creates BufferedRefs for the pipeline.
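
The following sketch shows one plausible way to use `emit_pipeline`: the outer kernel receives whole arrays that are not automatically copied into VMEM, and the emitted pipeline streams `(8, 128)` blocks through an inner kernel. The names `add_one_pipelined`, `outer_kernel`, and `inner_kernel`, the block shape, the requirement that the input has shape `(m, 128)` with `m` a multiple of 8, and the use of `pl.ANY` for the outer specs are all assumptions for illustration. It is intended for a TPU backend and has not been validated in interpret mode.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def add_one_pipelined(x):
    """Hypothetical example: adds 1 to x of shape (m, 128), m a multiple of 8."""

    def outer_kernel(x_hbm_ref, o_hbm_ref):
        def inner_kernel(x_vmem_ref, o_vmem_ref):
            # Runs once per grid step on blocks the pipeline has staged in VMEM.
            o_vmem_ref[...] = x_vmem_ref[...] + 1.0

        # emit_pipeline returns a function that is called on the outer refs.
        pltpu.emit_pipeline(
            inner_kernel,
            grid=(x.shape[0] // 8,),
            in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
            out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
        )(x_hbm_ref, o_hbm_ref)

    return pl.pallas_call(
        outer_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # Assumption: pl.ANY leaves both operands outside VMEM so that the
        # inner pipeline controls all data movement.
        in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
        out_specs=pl.BlockSpec(memory_space=pl.ANY),
    )(x)
```

Nesting the pipeline inside the kernel, rather than relying on the grid of `pl.pallas_call`, is useful when the kernel needs to do other work (for example, communication) around the pipelined loop.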

Pseudorandom Number Generation#

prng_seed(*seeds)

Sets the seed for PRNG.

sample_block(sampler_fn, global_key, ...[, ...])

Samples a block of random values with invariance guarantees.

stateful_bernoulli(*args, **kwargs)

Samples Bernoulli random values with given shape and mean.

stateful_bits(*args, **kwargs)

Samples uniform bits in the form of unsigned integers.

stateful_normal(*args, **kwargs)

Samples standard normal random values with given shape and float dtype.

stateful_uniform(*args, **kwargs)

Samples uniform random values in [minval, maxval) with given shape/dtype.

to_pallas_key(key)

Helper function for converting non-Pallas PRNG keys into Pallas keys.

Interpret Mode#

force_tpu_interpret_mode([params])

Context manager that forces TPU interpret mode under its dynamic context.

InterpretParams(*[, detect_races, ...])

Parameters for TPU interpret mode.

reset_tpu_interpret_mode_state()

Resets all global, shared state used by TPU interpret mode.

set_tpu_interpret_mode([params])

Miscellaneous#

core_barrier(sem, *, core_axis_name)

Synchronizes all cores in a given axis.

get_barrier_semaphore()

Returns a barrier semaphore.

get_tpu_info()

Returns the TPU hardware information for the current device.

is_tpu_device()

Returns whether the current device is a TPU.

run_on_first_core(core_axis_name)

Runs a function on the first core in a given axis.

with_memory_space_constraint(x, memory_space)

Constrains the memory space of an array.