jax.extend.mlir.refine_polymorphic_shapes
- jax.extend.mlir.refine_polymorphic_shapes = <nanobind.nb_func object>
refine_polymorphic_shapes(mlir_module: bytes, enable_shape_assertions: bool = True, validate_static_shapes: bool = True, enable_shardy: bool = False) -> bytes
- Refines the dynamic shapes for a module.
The “main” function must have static shapes, and all intermediate dynamic shapes must depend only on those static input shapes. Optionally (when validate_static_shapes is True), also validates that the resulting module has only static shapes.