jax.experimental.pallas.tpu.with_memory_space_constraint

jax.experimental.pallas.tpu.with_memory_space_constraint(x, memory_space)

Constrains the memory space of an array.

This primitive does not change the value of x; it only constrains the memory space in which x is allocated. Use it to force Pallas to place an array in a specific memory space.

Currently, the constraint only takes effect on the inputs of a pallas_call: you can apply it to an array that is passed as an argument to pallas_call and that argument will be constrained, but other operations will not respect the constraint.

Parameters:
  • x (jax.Array) – The array to constrain.

  • memory_space (Any) – The memory space to constrain to.

Returns:

The array x with the memory space constraint.

Return type:

jax.Array
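
Example:

The following is a minimal sketch of how the constraint might be applied to a pallas_call argument. It is not taken from the JAX sources. The kernel add_one_kernel and the wrapper add_one are hypothetical names, and the memory space is spelled pltpu.MemorySpace.VMEM on the assumption that the installed release exposes the enum under that name (some releases spell it differently). The block shape (8, 128) is likewise an illustrative choice. The call at the bottom is guarded so that it only runs when a TPU backend is available.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def add_one_kernel(x_ref, o_ref):
    # Hypothetical kernel: read the whole input block, add one, write it out.
    o_ref[...] = x_ref[...] + 1.0


@jax.jit
def add_one(x):
    # Assumption: the TPU memory-space enum is exposed as pltpu.MemorySpace.
    vmem = pltpu.MemorySpace.VMEM
    # Constrain the operand before it is passed to pallas_call; the value of
    # x is unchanged, only its requested memory space.
    x = pltpu.with_memory_space_constraint(x, vmem)
    return pl.pallas_call(
        add_one_kernel,
        in_specs=[pl.BlockSpec(memory_space=vmem)],
        out_specs=pl.BlockSpec(memory_space=vmem),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)


if jax.default_backend() == "tpu":
    # Only attempt to run the kernel when a TPU is actually present.
    x = jnp.zeros((8, 128), dtype=jnp.float32)
    print(add_one(x))
```

Because the constraint is only honored for pallas_call inputs, the sketch applies it immediately before the call rather than earlier in the computation.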