jax.experimental.pallas.multiple_of

jax.experimental.pallas.multiple_of(x, values)

A compiler hint that asserts a value is a static multiple of another.

Note that misusing this function, such as asserting x is a multiple of N when it is not, can result in undefined behavior.
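The following is a minimal sketch, not taken from the JAX documentation, of one common use: telling the compiler that a dynamically computed slice offset inside a Pallas kernel is aligned to the block size. The kernel `scale_kernel`, the wrapper `scale`, and the block size `BLOCK = 8` are hypothetical names chosen for illustration. The sketch assumes `interpret=True` so that it runs on CPU; on a real GPU or TPU backend, block sizes and layouts may have to satisfy additional hardware constraints.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

BLOCK = 8  # hypothetical block size for this sketch


def scale_kernel(x_ref, o_ref):
    # Each program instance handles one BLOCK-sized slice of the array.
    i = pl.program_id(0)
    # Hint that the offset is divisible by BLOCK. This is true by
    # construction (start == i * BLOCK); a false hint is undefined behavior.
    start = pl.multiple_of(i * BLOCK, BLOCK)
    o_ref[pl.ds(start, BLOCK)] = x_ref[pl.ds(start, BLOCK)] * 2.0


def scale(x):
    # No BlockSpecs: every program sees the whole array and writes its slice.
    return pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // BLOCK,),
        interpret=True,  # assumption: run on CPU via the Pallas interpreter
    )(x)
```

Here the hint is truthful because `start` is always `i * BLOCK`. Passing a divisor that does not actually divide the offset would fall under the undefined behavior noted above.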

Parameters:
  • x (jax_typing.Array) – The input array.

  • values (Sequence[int] | int) – A static divisor, or a sequence of static divisors, that x is asserted to be a multiple of.

Returns:
  A copy of x.

Return type:
  jax_typing.Array