jax.experimental.pallas.tpu.get_tpu_info

jax.experimental.pallas.tpu.get_tpu_info()

Returns the TPU hardware info for the current device.

Note that all values are reported per TensorCore. To obtain the total for the whole chip, multiply by num_cores.

Return type:

TpuInfo
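Because every value is reported per TensorCore, a chip-wide figure has to be scaled by num_cores. The following is a minimal sketch of that scaling. It assumes only what the note above states: the returned object exposes num_cores. The helper chip_total, the stand-in class FakeTpuInfo, and its field vmem_bytes are hypothetical names introduced here so the example runs without TPU hardware. They are not part of the TpuInfo API.

```python
from dataclasses import dataclass
from typing import Any


def chip_total(info: Any, field: str) -> int:
    """Scale a per-TensorCore value from a TpuInfo-like object to the whole chip.

    Assumes `info` has a `num_cores` attribute (as the note above states)
    and that `field` names a per-TensorCore numeric attribute.
    """
    per_core = getattr(info, field)
    return per_core * info.num_cores


# Hypothetical stand-in for TpuInfo so the sketch runs without a TPU.
# `vmem_bytes` is an illustrative field name, not necessarily a real TpuInfo field.
@dataclass
class FakeTpuInfo:
    num_cores: int
    vmem_bytes: int


info = FakeTpuInfo(num_cores=2, vmem_bytes=64 * 1024 * 1024)
print(chip_total(info, "vmem_bytes"))  # 134217728
```

On a TPU host, you would replace the stand-in with the real result, for example chip_total(jax.experimental.pallas.tpu.get_tpu_info(), name). Here name is one of TpuInfo's per-core fields. Check the TpuInfo definition for the actual field names before relying on this.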