jax.experimental.pallas.tpu.PrefetchScalarGridSpec

class jax.experimental.pallas.tpu.PrefetchScalarGridSpec(num_scalar_prefetch: 'int', grid: 'pallas_core.Grid' = (), in_specs: 'pallas_core.BlockSpecTree' = NoBlockSpec, out_specs: 'pallas_core.BlockSpecTree' = NoBlockSpec, scratch_shapes: 'pallas_core.ScratchShapeTree' = ())
Grid specification for TPU pallas_call kernels that use scalar prefetch.
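Parameters:
  • num_scalar_prefetch (int) – number of leading operands of the pallas_call that are prefetched as scalars (into SMEM on TPU). These operands are not described by in_specs; their refs are passed to every index_map after the grid indices, and to the kernel before the input refs.

  • grid (TupleGrid) – the iteration grid, stored as a tuple of axis sizes; the kernel body runs once per grid point.

  • in_specs (BlockSpecTree) – block specifications for the remaining (non-prefetched) inputs.

  • out_specs (BlockSpecTree) – block specifications for the outputs.

  • scratch_shapes (ScratchShapeTree) – additional scratch buffers (for example pltpu.VMEM(shape, dtype)) allocated for the kernel; their refs are passed after the output refs.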
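A minimal sketch of how num_scalar_prefetch changes the calling convention, assuming a recent JAX where pl.BlockSpec accepts block_shape and index_map keywords. The hypothetical block_gather helper uses a prefetched int32 index array to choose which (8, 128) block of x each grid step copies. interpret=True runs it through the Pallas interpreter so it can be tried without a TPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

BLOCK_ROWS, LANES = 8, 128  # one (8, 128) tile per block
NUM_BLOCKS = 4


def gather_kernel(idx_ref, x_ref, o_ref):
    # Scalar-prefetch refs come first, then input refs, then output refs.
    # x_ref already holds the block chosen by the input index_map below.
    del idx_ref  # the indices are only needed by the index_map here
    o_ref[...] = x_ref[...]


def block_gather(indices, x):
    """Hypothetical helper: output block i is input block indices[i]."""
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=1,  # `indices` is prefetched; `x` is a regular input
        grid=(NUM_BLOCKS,),
        # Each index_map receives the grid index, then the scalar-prefetch refs.
        in_specs=[
            pl.BlockSpec(
                block_shape=(BLOCK_ROWS, LANES),
                index_map=lambda i, idx_ref: (idx_ref[i], 0),
            )
        ],
        out_specs=pl.BlockSpec(
            block_shape=(BLOCK_ROWS, LANES),
            index_map=lambda i, idx_ref: (i, 0),
        ),
    )
    return pl.pallas_call(
        gather_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid_spec=grid_spec,
        interpret=True,  # Pallas interpreter; remove to compile for a real TPU
    )(indices, x)


indices = jnp.array([2, 0, 3, 1], dtype=jnp.int32)
x = jnp.arange(NUM_BLOCKS * BLOCK_ROWS * LANES, dtype=jnp.float32).reshape(
    NUM_BLOCKS * BLOCK_ROWS, LANES
)
out = block_gather(indices, x)
# Reference result computed with plain jax.numpy fancy indexing.
expected = x.reshape(NUM_BLOCKS, BLOCK_ROWS, LANES)[indices].reshape(x.shape)
```

Because the index map sees the prefetched values, which block each grid step reads can depend on runtime data (here, a permutation) without changing the grid itself.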

__init__(num_scalar_prefetch, grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=())
Parameters:
  • num_scalar_prefetch (int)

  • grid (pallas_core.Grid)

  • in_specs (pallas_core.BlockSpecTree)

  • out_specs (pallas_core.BlockSpecTree)

  • scratch_shapes (pallas_core.ScratchShapeTree)
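The kernel body can also read the prefetched scalars, and scratch_shapes appends extra buffers after the output refs. A sketch under the same assumptions (recent JAX, interpreter mode); scale_blocks is a hypothetical helper that multiplies each row block by its own scalar, staging the product in a VMEM scratch buffer before adding 1.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

ROWS, LANES, NUM_BLOCKS = 8, 128, 3


def scale_kernel(scale_ref, x_ref, o_ref, tmp_ref):
    # Argument order: scalar-prefetch refs, input refs, output refs, scratch refs.
    i = pl.program_id(0)
    s = scale_ref[i].astype(jnp.float32)  # per-block scalar from the prefetched array
    tmp_ref[...] = x_ref[...] * s         # stage the product in VMEM scratch
    o_ref[...] = tmp_ref[...] + 1.0


def scale_blocks(scales, x):
    """Hypothetical helper: out[block i] = x[block i] * scales[i] + 1."""
    spec = pl.BlockSpec(
        block_shape=(ROWS, LANES),
        index_map=lambda i, scale_ref: (i, 0),  # scalar ref unused by this map
    )
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=1,
        grid=(NUM_BLOCKS,),
        in_specs=[spec],
        out_specs=spec,
        scratch_shapes=[pltpu.VMEM((ROWS, LANES), jnp.float32)],
    )
    return pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid_spec=grid_spec,
        interpret=True,  # Pallas interpreter; remove to compile for a real TPU
    )(scales, x)


scales = jnp.array([1, 2, 3], dtype=jnp.int32)
x = jnp.ones((NUM_BLOCKS * ROWS, LANES), dtype=jnp.float32)
out = scale_blocks(scales, x)
```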

Methods

__init__(num_scalar_prefetch[, grid, ...])

Attributes

scratch_shapes

num_scalar_prefetch

grid

grid_names (inherited from the base GridSpec; not a constructor argument)

in_specs

out_specs