jax.extend.mlir.refine_polymorphic_shapes

jax.extend.mlir.refine_polymorphic_shapes(mlir_module: bytes, enable_shape_assertions: bool = True, validate_static_shapes: bool = True, enable_shardy: bool = False) -> bytes

Refines the dynamic shapes for a module.

The “main” function must have static shapes, and all intermediate dynamic shapes must depend only on its static input shapes. Optionally (see validate_static_shapes), the function also validates that the resulting module has only static shapes.

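Parameters:

mlir_module (bytes): the serialized MLIR module to refine.
enable_shape_assertions (bool, default True): whether shape assertions are enabled during refinement.
validate_static_shapes (bool, default True): whether to validate that the refined module has only static shapes.
enable_shardy (bool, default False): whether to enable Shardy support.

Returns:

bytes: the refined module, serialized.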