jax.experimental.pallas.tpu.get_tpu_info

jax.experimental.pallas.tpu.get_tpu_info()

Returns the TPU hardware information for the current device.

Note that all information is reported per TensorCore, so multiply by num_cores to obtain the total for the chip.

Returns:

A TpuInfo object containing the hardware information for the current device.

Return type:

TpuInfo
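
The per-TensorCore note above can be illustrated with a minimal sketch. It uses a stand-in object with made-up values so that it runs without a TPU. The helper `chip_total` is hypothetical, and `vmem_capacity_bytes` is an assumed field name used only for illustration; of the names below, only `num_cores` and `get_tpu_info` appear in this page.

```python
from types import SimpleNamespace


def chip_total(info, field_name):
    """Scale a per-TensorCore quantity to a whole-chip total."""
    per_core = getattr(info, field_name)
    return per_core * info.num_cores


# Stand-in for a TpuInfo object; the values are invented for illustration.
# `vmem_capacity_bytes` is an assumed field name, not confirmed by this page.
fake_info = SimpleNamespace(num_cores=2, vmem_capacity_bytes=1024)
total = chip_total(fake_info, "vmem_capacity_bytes")
print(total)

# On a TPU host, the real object would instead come from:
#   from jax.experimental.pallas import tpu as pltpu
#   info = pltpu.get_tpu_info()
```

Scaling by `num_cores` only makes sense for quantities that add up across cores, such as capacities or throughput. Check the fields of the returned TpuInfo on your installed JAX version before relying on any particular name.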