jax.extend.pallas.GridMapping
- class jax.extend.pallas.GridMapping(grid, grid_names, block_mappings, index_map_tree, index_map_avals, vmapped_dims, scratch_avals, num_index_operands, num_inputs, num_outputs, get_grid_indices=None, local_grid_env=None, debug=False)
An internal canonicalized version of GridSpec.
Encodes the calling conventions of the pallas_call primitive, the kernel, and the index maps.
The pallas_call is invoked with:
*dynamic_grid_sizes, *index, *inputs. The index operands are for the scalar prefetch.
The kernel function is invoked with:
*index, *inputs, *scratch.
The index map functions are invoked with:
*program_ids, *index.
See the check_invariants method for a more precise specification.
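The three calling conventions above can be sketched in plain Python. This is only an illustration of the operand ordering: the helper names (split_pallas_call_operands, kernel_args, index_map_args) are hypothetical and not part of Pallas, and the string placeholders stand in for real arrays and grid sizes.

```python
# Hypothetical helpers for illustration; they are not part of the Pallas API.

def split_pallas_call_operands(operands, num_dynamic_grid_bounds, num_index_operands):
    """Split a flat `*dynamic_grid_sizes, *index, *inputs` sequence into groups."""
    operands = tuple(operands)
    index_start = num_dynamic_grid_bounds
    index_end = index_start + num_index_operands
    return (
        operands[:index_start],           # dynamic grid sizes
        operands[index_start:index_end],  # index (scalar prefetch) operands
        operands[index_end:],             # inputs
    )


def kernel_args(index_operands, inputs, scratch):
    """Assemble kernel arguments in the order `*index, *inputs, *scratch`."""
    return (*index_operands, *inputs, *scratch)


def index_map_args(program_ids, index_operands):
    """Assemble index map arguments in the order `*program_ids, *index`."""
    return (*program_ids, *index_operands)


# Placeholder strings stand in for real arrays and grid sizes.
flat = ("grid_size_0", "prefetch_0", "x", "y")
dynamic_sizes, index_ops, inputs = split_pallas_call_operands(
    flat, num_dynamic_grid_bounds=1, num_index_operands=1
)
```

Note that the index operands appear in all three conventions: they are passed to pallas_call, forwarded to the kernel, and also supplied to every index map after the program ids.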
- Parameters:
grid (GridMappingGrid)
grid_names (tuple[Hashable, ...] | None)
block_mappings (tuple[BlockMapping, ...])
index_map_tree (tree_util.PyTreeDef)
index_map_avals (tuple[jax_core.AbstractValue, ...])
scratch_avals (tuple[jax_core.AbstractValue, ...])
num_index_operands (int)
num_inputs (int)
num_outputs (int)
get_grid_indices (Callable | None)
local_grid_env (Callable | None)
debug (bool)
- __init__(grid, grid_names, block_mappings, index_map_tree, index_map_avals, vmapped_dims, scratch_avals, num_index_operands, num_inputs, num_outputs, get_grid_indices=None, local_grid_env=None, debug=False)
- Parameters:
grid (GridMappingGrid)
grid_names (tuple[Hashable, ...] | None)
block_mappings (tuple[BlockMapping, ...])
index_map_tree (tree_util.PyTreeDef)
index_map_avals (tuple[jax_core.AbstractValue, ...])
scratch_avals (tuple[jax_core.AbstractValue, ...])
num_index_operands (int)
num_inputs (int)
num_outputs (int)
get_grid_indices (Callable | None)
local_grid_env (Callable | None)
debug (bool)
- Return type:
None
Methods
__init__(grid, grid_names, block_mappings, ...)
check_invariants()
replace(**kwargs)
to_lojax()
trace_env()
Attributes
block_mappings_output
debug
get_grid_indices
in_shapes: The shapes of *index, *inputs.
local_grid_env
num_dynamic_grid_bounds
num_scratch_operands
out_shapes
slice_block_ops: Returns a slice to select the block operands to a kernel.
slice_index_ops: Returns a slice object to select the index operands to a kernel.
slice_scratch_ops: Returns a slice object to select the scratch operands to a kernel.
static_grid
grid
grid_names
block_mappings
index_map_tree
index_map_avals
vmapped_dims
scratch_avals
num_index_operands
num_inputs
num_outputs
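The slice_index_ops, slice_block_ops, and slice_scratch_ops attributes each return a slice into the kernel's flat operand list. A minimal sketch of that idea follows, assuming the kernel operands are laid out as index operands, then block operands, then scratch operands. The function kernel_operand_slices and its parameters are hypothetical, and the actual Pallas implementation may compute these slices differently.

```python
# Hypothetical sketch; the real GridMapping properties may be computed differently.

def kernel_operand_slices(num_index_operands, num_block_operands, num_scratch_operands):
    """Return slices selecting each operand group from a flat kernel operand list.

    Assumes the layout: index operands, then block operands, then scratch operands.
    """
    block_start = num_index_operands
    scratch_start = block_start + num_block_operands
    return {
        "index": slice(0, num_index_operands),
        "block": slice(block_start, scratch_start),
        "scratch": slice(scratch_start, scratch_start + num_scratch_operands),
    }


# Placeholder strings stand in for the references a kernel would receive.
kernel_operands = ["prefetch_0", "x_block", "y_block", "scratch_0"]
slices = kernel_operand_slices(
    num_index_operands=1, num_block_operands=2, num_scratch_operands=1
)
index_part = kernel_operands[slices["index"]]
block_part = kernel_operands[slices["block"]]
scratch_part = kernel_operands[slices["scratch"]]
```

Returning slice objects rather than copies lets the same object index any sequence that shares the operand layout, such as a list of abstract values or a list of shapes.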