jax.extend.pallas module


GridMapping(grid, grid_names, ...[, ...])

An internal canonicalized version of GridSpec.

register_lowering_rule(params_cls, rule, ...)