jax.experimental.pallas.tpu.TpuInfo
- class jax.experimental.pallas.tpu.TpuInfo(*, chip_version, generation, num_cores, num_lanes, num_sublanes, mxu_column_size, vmem_capacity_bytes, cmem_capacity_bytes, smem_capacity_bytes, hbm_capacity_bytes, mem_bw_bytes_per_second, bf16_ops_per_second, int8_ops_per_second, fp8_ops_per_second, int4_ops_per_second, sparse_core=None)
TPU hardware information.
Note that all capacity and throughput values are per TensorCore, so multiply them by num_cores to obtain the total for the chip.
- Parameters:
chip_version (ChipVersionBase)
generation (int)
num_cores (int)
num_lanes (int)
num_sublanes (int)
mxu_column_size (int)
vmem_capacity_bytes (int)
cmem_capacity_bytes (int)
smem_capacity_bytes (int)
hbm_capacity_bytes (int)
mem_bw_bytes_per_second (int)
bf16_ops_per_second (int)
int8_ops_per_second (int)
fp8_ops_per_second (int)
int4_ops_per_second (int)
sparse_core (SparseCoreInfo | None)
- __init__(*, chip_version, generation, num_cores, num_lanes, num_sublanes, mxu_column_size, vmem_capacity_bytes, cmem_capacity_bytes, smem_capacity_bytes, hbm_capacity_bytes, mem_bw_bytes_per_second, bf16_ops_per_second, int8_ops_per_second, fp8_ops_per_second, int4_ops_per_second, sparse_core=None)
- Parameters:
chip_version (ChipVersionBase)
generation (int)
num_cores (int)
num_lanes (int)
num_sublanes (int)
mxu_column_size (int)
vmem_capacity_bytes (int)
cmem_capacity_bytes (int)
smem_capacity_bytes (int)
hbm_capacity_bytes (int)
mem_bw_bytes_per_second (int)
bf16_ops_per_second (int)
int8_ops_per_second (int)
fp8_ops_per_second (int)
int4_ops_per_second (int)
sparse_core (SparseCoreInfo | None)
- Return type:
None
Methods
__init__(*, chip_version, generation, ...[, ...])
get_sublane_tiling(dtype): Returns the sublane tiling for the itemsize of the given dtype.
is_matmul_supported(lhs_dtype, rhs_dtype): Returns whether the given matmul input dtypes are supported on the chip.
Attributes
is_lite
is_split_chip
sparse_core
chip_version
generation
num_cores
num_lanes
num_sublanes
mxu_column_size
vmem_capacity_bytes
cmem_capacity_bytes
smem_capacity_bytes
hbm_capacity_bytes
mem_bw_bytes_per_second
bf16_ops_per_second
int8_ops_per_second
fp8_ops_per_second
int4_ops_per_second