jax.experimental.pallas.tpu module
Mosaic-specific Pallas APIs.
Classes
Mosaic TPU compiler parameters.
TPU hardware information.
Communication
Issues a DMA copying from src_ref to dst_ref.
Issues a remote DMA copying from src_ref to dst_ref.
Creates a description of an asynchronous copy operation.
Creates a description of a remote copy operation.
Copies a PyTree of Refs to another PyTree of Refs.
Pipelining
A helper class to automate VMEM double buffering in Pallas pipelines.
Abstract interface for BufferedRefs.
Creates a function to emit a manual Pallas pipeline.
Creates Pallas pipeline and top-level allocation preparation functions.
Retrieves a named pipeline schedule, or passes through a fully specified one.
Creates BufferedRefs for the pipeline.
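The double-buffering idea mentioned above can be sketched in plain Python. This is only a conceptual model under stated assumptions: it does not call any Pallas API, the name `double_buffered_map` is hypothetical, and a list copy stands in for a DMA into fast memory. On real hardware the prefetch would overlap with the compute step; here the two simply run in sequence.

```python
from typing import Callable, List, Sequence


def double_buffered_map(
    blocks: Sequence[List[float]],
    compute: Callable[[List[float]], List[float]],
) -> List[List[float]]:
    """Applies `compute` to each block using two alternating staging slots.

    Hypothetical helper for illustration; not part of any library.
    """
    slots: List[List[float]] = [[], []]  # stand-ins for two fast-memory buffers
    outputs: List[List[float]] = []
    if not blocks:
        return outputs

    # Prologue: fetch the first block into slot 0 (models the initial copy-in).
    slots[0] = list(blocks[0])

    for step in range(len(blocks)):
        current = step % 2
        upcoming = (step + 1) % 2
        # "Prefetch" the next block into the idle slot. A real pipeline would
        # issue this copy asynchronously so it overlaps with the compute below.
        if step + 1 < len(blocks):
            slots[upcoming] = list(blocks[step + 1])
        # Compute on the slot that was filled during the previous step.
        outputs.append(compute(slots[current]))
    return outputs


blocks = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
doubled = double_buffered_map(blocks, lambda xs: [2.0 * x for x in xs])
print(doubled)
```

Two slots are enough because a block is only needed for one compute step: while one slot is being read, the other is free to receive the next block.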
Pseudorandom Number Generation
Sets the seed for the PRNG.
Samples a block of random values with invariance guarantees.
Samples Bernoulli random values with the given shape and mean.
Samples uniform bits in the form of unsigned integers.
Samples standard normal random values with the given shape and float dtype.
Samples uniform random values in [minval, maxval) with the given shape and dtype.
Helper function for converting non-Pallas PRNG keys into Pallas keys.
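The relationship between raw unsigned bits and the derived distributions listed above can be illustrated with a small, generic sketch. It is not the Pallas implementation, and the function names are hypothetical. It assumes 32-bit unsigned inputs and uses Python's `random` module only as a stand-in bit source.

```python
import random


def bits_to_unit_float(bits: int) -> float:
    """Maps a 32-bit unsigned integer to a float in [0, 1)."""
    # Keep the top 24 bits so the result can never round up to 1.0.
    return (bits >> 8) / float(1 << 24)


def uniform_from_bits(bits: int, minval: float, maxval: float) -> float:
    """Scales a unit float into the half-open interval [minval, maxval)."""
    return minval + bits_to_unit_float(bits) * (maxval - minval)


def bernoulli_from_bits(bits: int, mean: float) -> bool:
    """Returns True with probability approximately `mean`."""
    return bits_to_unit_float(bits) < mean


rng = random.Random(0)  # stand-in source of uniform 32-bit values
raw_bits = [rng.getrandbits(32) for _ in range(1000)]
samples = [uniform_from_bits(b, -1.0, 1.0) for b in raw_bits]
flips = [bernoulli_from_bits(b, 0.5) for b in raw_bits]
```

The upper bound is exclusive because the largest unit float produced is 1 - 2^-24, which is why the interval is written as [minval, maxval).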
Interpret Mode
Context manager that forces TPU interpret mode under its dynamic context.
Parameters for TPU interpret mode.
Resets all global, shared state used by TPU interpret mode.
Miscellaneous
Synchronizes all cores in a given axis.
Returns a barrier semaphore.
Returns the TPU hardware information for the current device.
Returns whether the current device is a TPU.
Runs a function on the first core in a given axis.
Constrains the memory space of an array.