jax.experimental.pallas.tpu.get_tpu_info
- jax.experimental.pallas.tpu.get_tpu_info()
Returns the TPU hardware information for the current device.
Note that all information is reported per TensorCore, so multiply by num_cores to obtain the total for the whole chip.
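The per-core convention described in the note can be illustrated with a small helper. This is a minimal sketch, not part of the Pallas API. chip_total is a hypothetical helper, and it assumes the returned TpuInfo exposes num_cores plus per-core attributes such as vmem_capacity_bytes (that field name is an assumption). FakeTpuInfo is a hypothetical stand-in so the arithmetic runs without a TPU.

```python
import dataclasses
from typing import Any


def chip_total(info: Any, per_core_field: str) -> int:
    """Scale a per-TensorCore quantity in `info` up to the whole chip.

    `info` is assumed to be TpuInfo-like: it must expose `num_cores` and
    the named per-core attribute (attribute names here are assumptions).
    """
    per_core_value = getattr(info, per_core_field)
    return per_core_value * info.num_cores


# On a TPU host the call might look like this (field name is hypothetical):
#   from jax.experimental.pallas import tpu as pltpu
#   info = pltpu.get_tpu_info()
#   total_vmem = chip_total(info, "vmem_capacity_bytes")


# Hypothetical stand-in so the arithmetic can run without a TPU.
@dataclasses.dataclass
class FakeTpuInfo:
    num_cores: int
    vmem_capacity_bytes: int


fake = FakeTpuInfo(num_cores=2, vmem_capacity_bytes=64 * 1024 * 1024)
print(chip_total(fake, "vmem_capacity_bytes"))  # 134217728
```

Passing the field name as a string keeps the helper independent of which per-core attributes a given JAX version actually defines.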
- Returns:
A TpuInfo object containing the hardware information for the current device.
- Return type: