jax.experimental.pallas.tpu.ChipVersion

class jax.experimental.pallas.tpu.ChipVersion(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Enumeration of TPU chip versions.

The following table summarizes the differences between TPU versions:

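=======  =============================  =========  ================
Version  Physical TensorCores per chip  Lite chip  Megacore support
=======  =============================  =========  ================
v2       2                              No         No
v3       2                              No         No
v4i      1                              Yes        No
v4       2                              No         Yes
v5e      1                              Yes        No
v5p      2                              No         Yes
v6e      1                              Yes        No
7        2                              No         No
7x       2                              No         No
=======  =============================  =========  ================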
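The table above can be read as a small lookup structure. The following is a minimal sketch that encodes its rows in a standalone enum, so the relationship between members and the three attributes is easy to see. The class name `ChipVersionSketch` and the tuple layout of its values are assumptions made for illustration only; this is not the library's implementation, and the real member values may differ.

```python
import enum


# Hypothetical stand-in that mirrors the table above; it is NOT the
# jax.experimental.pallas.tpu.ChipVersion implementation.
class ChipVersionSketch(enum.Enum):
    # Each value is (label, physical TensorCores per chip, lite chip, megacore).
    TPU_V2 = ("v2", 2, False, False)
    TPU_V3 = ("v3", 2, False, False)
    TPU_V4I = ("v4i", 1, True, False)
    TPU_V4 = ("v4", 2, False, True)
    TPU_V5E = ("v5e", 1, True, False)
    TPU_V5P = ("v5p", 2, False, True)
    TPU_V6E = ("v6e", 1, True, False)
    TPU_7 = ("7", 2, False, False)
    TPU_7X = ("7x", 2, False, False)

    @property
    def num_physical_tensor_cores_per_chip(self):
        return self.value[1]

    @property
    def is_lite(self):
        return self.value[2]

    @property
    def supports_megacore(self):
        return self.value[3]


# Filter members by the columns of the table.
megacore_versions = [v.name for v in ChipVersionSketch if v.supports_megacore]
lite_versions = [v.name for v in ChipVersionSketch if v.is_lite]
```

In this sketch, every lite chip has one physical TensorCore, and only the non-lite v4 and v5p rows report Megacore support, which matches the table.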

__init__(*args, **kwds)

Attributes

num_physical_tensor_cores_per_chip

supports_megacore

is_lite

TPU_V2

TPU_V3

TPU_V4I

TPU_V4

TPU_V5E

TPU_V5P

TPU_V6E

TPU_7

TPU_7X
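On an installation that provides this class, the attributes listed above can be read from each member. The following is a minimal sketch, assuming the three attributes behave as read-only properties on the enum members, as the attribute listing suggests. The helper name `describe_chip_versions` is hypothetical, and the function returns an empty list when the module or class is unavailable, since older JAX releases may not include it.

```python
import importlib


def describe_chip_versions():
    """Return (name, cores, is_lite, supports_megacore) for each member.

    Returns an empty list if jax.experimental.pallas.tpu or its
    ChipVersion class cannot be imported in this environment.
    """
    try:
        pltpu = importlib.import_module("jax.experimental.pallas.tpu")
        chip_version = pltpu.ChipVersion
    except (ImportError, AttributeError):
        return []

    rows = []
    for member in chip_version:
        rows.append(
            (
                member.name,
                member.num_physical_tensor_cores_per_chip,
                member.is_lite,
                member.supports_megacore,
            )
        )
    return rows


rows = describe_chip_versions()
for name, cores, lite, megacore in rows:
    print(f"{name}: cores={cores}, lite={lite}, megacore={megacore}")
```

Iterating over the enum rather than hard-coding member names keeps the sketch valid if later releases add chip versions.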