jax.extend.pallas.GridMapping

class jax.extend.pallas.GridMapping(grid, grid_names, block_mappings, index_map_tree, index_map_avals, vmapped_dims, scratch_avals, num_index_operands, num_inputs, num_outputs, get_grid_indices=None, local_grid_env=None, debug=False)

An internal canonicalized version of GridSpec.

Encodes the calling conventions of the pallas_call primitive, the kernel, and the index maps.

The pallas_call primitive is invoked with: *dynamic_grid_sizes, *index, *inputs. The index operands hold the scalar-prefetch values.
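As a rough illustration of this operand layout, the sketch below splits a flat pallas_call operand list into its three groups. The helper split_call_operands is hypothetical and works on plain Python lists. It only mirrors the ordering described above and is not part of the Pallas API.

```python
from typing import Any, Sequence


def split_call_operands(
    operands: Sequence[Any],
    num_dynamic_grid_bounds: int,
    num_index_operands: int,
) -> tuple[list[Any], list[Any], list[Any]]:
    """Hypothetical helper: split pallas_call operands into
    (dynamic_grid_sizes, index, inputs), assuming that order."""
    grid_end = num_dynamic_grid_bounds
    index_end = grid_end + num_index_operands
    dynamic_grid_sizes = list(operands[:grid_end])
    index = list(operands[grid_end:index_end])
    inputs = list(operands[index_end:])
    return dynamic_grid_sizes, index, inputs


# One dynamic grid bound, one scalar-prefetch operand, two inputs.
sizes, index, inputs = split_call_operands([16, "prefetch", "x", "y"], 1, 1)
print(sizes, index, inputs)  # [16] ['prefetch'] ['x', 'y']
```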

The kernel function is invoked with: *index, *inputs, *outputs, *scratch. The output references follow the inputs and precede the scratch buffers.
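To make the ordering concrete, here is a toy kernel written in the Pallas style but run on plain Python lists that stand in for refs. The names scale_and_sum_kernel and invoke_kernel are hypothetical. The sketch assumes, as in Pallas, that output refs follow the input refs and precede the scratch refs.

```python
def scale_and_sum_kernel(scale_ref, x_ref, o_ref, acc_ref):
    """Toy kernel: *index = (scale_ref,), *inputs = (x_ref,),
    *outputs = (o_ref,), *scratch = (acc_ref,)."""
    scale = scale_ref[0]          # scalar-prefetch value
    acc_ref[0] = 0                # scratch: running sum of the input block
    for i, v in enumerate(x_ref):
        o_ref[i] = v * scale      # write the output block in place
        acc_ref[0] += v


def invoke_kernel(kernel, index, inputs, outputs, scratch):
    """Hypothetical driver: call the kernel with operands in the assumed order."""
    kernel(*index, *inputs, *outputs, *scratch)


x = [1, 2, 3]
out = [0, 0, 0]
acc = [0]
invoke_kernel(scale_and_sum_kernel, [[2]], [x], [out], [acc])
print(out, acc)  # [2, 4, 6] [6]
```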

The index map functions are invoked with: *program_ids, *index.
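The index-map convention can be sketched the same way. Below, a hypothetical helper evaluates an index map at every point of a grid, passing the program ids first and the scalar-prefetch operands after them. The example index map picks its row block from a prefetched table, a typical use of scalar prefetch. The helper and the table are illustrative only, and the row-major visiting order is simply what itertools.product produces, not a claim about any backend.

```python
import itertools
from typing import Any, Callable


def eval_index_map(
    grid: tuple[int, ...],
    index_map: Callable[..., tuple[int, ...]],
    *index: Any,
) -> dict[tuple[int, ...], tuple[int, ...]]:
    """Hypothetical helper: call index_map(*program_ids, *index) at every grid point."""
    result = {}
    for program_ids in itertools.product(*(range(n) for n in grid)):
        result[program_ids] = index_map(*program_ids, *index)
    return result


def gather_rows(i, j, row_table):
    # Program ids (i, j) come first; the scalar-prefetch operand comes last.
    return (row_table[i], j)


blocks = eval_index_map((3, 2), gather_rows, [4, 0, 7])
print(blocks[(0, 1)], blocks[(2, 0)])  # (4, 1) (7, 0)
```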

See the check_invariants method for a more precise specification.
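The sketch below models two arity relations consistent with the conventions described here, stated as assumptions: one block mapping per input and per output, and one index-map argument per grid axis plus one per index operand. ToyGridMapping is hypothetical and stores counts instead of real block mappings and avals. It is not the library's implementation and omits any other checks the real method may perform.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyGridMapping:
    """Hypothetical stand-in for GridMapping that keeps only the counts."""
    grid: tuple[int, ...]
    num_block_mappings: int
    num_index_map_args: int
    num_index_operands: int
    num_inputs: int
    num_outputs: int

    def check_invariants(self) -> None:
        # Assumed: one block mapping for each input and each output.
        if self.num_block_mappings != self.num_inputs + self.num_outputs:
            raise ValueError("expected one block mapping per input and output")
        # Index maps receive *program_ids (one per grid axis), then *index.
        if self.num_index_map_args != len(self.grid) + self.num_index_operands:
            raise ValueError("index-map arity must be len(grid) + num_index_operands")


ok = ToyGridMapping(grid=(4, 2), num_block_mappings=3, num_index_map_args=3,
                    num_index_operands=1, num_inputs=2, num_outputs=1)
ok.check_invariants()  # passes silently
```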

Parameters:
  • grid (GridMappingGrid)

  • grid_names (tuple[Hashable, ...] | None)

  • block_mappings (tuple[BlockMapping, ...])

  • index_map_tree (tree_util.PyTreeDef)

  • index_map_avals (tuple[jax_core.AbstractValue, ...])

  • vmapped_dims (tuple[int, ...])

  • scratch_avals (tuple[jax_core.AbstractValue, ...])

  • num_index_operands (int)

  • num_inputs (int)

  • num_outputs (int)

  • get_grid_indices (Callable | None)

  • local_grid_env (Callable | None)

  • debug (bool)

__init__(grid, grid_names, block_mappings, index_map_tree, index_map_avals, vmapped_dims, scratch_avals, num_index_operands, num_inputs, num_outputs, get_grid_indices=None, local_grid_env=None, debug=False)
Parameters: the same as for the class signature above.

Return type:

None

Methods

__init__(grid, grid_names, block_mappings, ...)

check_invariants()

replace(**kwargs)

to_lojax()

trace_env()

Attributes

block_mappings_output

debug

get_grid_indices

in_shapes

The shapes of *index, *inputs.

local_grid_env

num_dynamic_grid_bounds

num_scratch_operands

out_shapes

slice_block_ops

Returns a slice object to select the block operands (inputs and outputs) to a kernel.

slice_index_ops

Returns a slice object to select the index operands to a kernel.

slice_scratch_ops

Returns a slice object to select the scratch operands to a kernel.

static_grid

grid

grid_names

block_mappings

index_map_tree

index_map_avals

vmapped_dims

scratch_avals

num_index_operands

num_inputs

num_outputs
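The slice_index_ops, slice_block_ops and slice_scratch_ops attributes partition the kernel's flat operand list. Here is a minimal sketch, assuming the kernel order *index, *block operands (inputs then outputs), *scratch. operand_slices is a hypothetical helper, and the real properties may compute equivalent bounds in another way.

```python
def operand_slices(num_index_operands, num_block_operands, num_scratch_operands):
    """Hypothetical helper returning (index, block, scratch) slices for a flat
    kernel operand list ordered as *index, *inputs, *outputs, *scratch."""
    block_start = num_index_operands
    scratch_start = block_start + num_block_operands
    index_ops = slice(0, block_start)
    block_ops = slice(block_start, scratch_start)
    scratch_ops = slice(scratch_start, scratch_start + num_scratch_operands)
    return index_ops, block_ops, scratch_ops


kernel_operands = ["prefetch", "x", "y", "o", "acc"]
idx, blk, scr = operand_slices(1, 3, 1)  # 2 inputs + 1 output = 3 block operands
print(kernel_operands[idx], kernel_operands[blk], kernel_operands[scr])
# ['prefetch'] ['x', 'y', 'o'] ['acc']
```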