jax.experimental.pallas.cdiv
- jax.experimental.pallas.cdiv(a: int, b: int) -> int
- jax.experimental.pallas.cdiv(a: int, b: jax_typing.Array) -> jax_typing.Array
- jax.experimental.pallas.cdiv(a: jax_typing.Array, b: int) -> jax_typing.Array
- jax.experimental.pallas.cdiv(a: jax_typing.Array, b: jax_typing.Array) -> jax_typing.Array
Computes the ceiling division of a by b, i.e. ceil(a / b): the quotient a / b rounded up to the nearest integer.
Examples
>>> cdiv(8, 2)
4
>>> cdiv(9, 2)  # 9 / 2 = 4.5, which rounds up to 5
5
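A common reason to round up is sizing a grid so that fixed-size blocks cover an array whose length is not a multiple of the block size. The sketch below illustrates that arithmetic in plain Python so it runs without JAX. It is an assumption-laden stand-in, not the library's implementation: `ceil_div` and `blocks_needed` are hypothetical names, and the sketch only covers integer inputs with a positive divisor.

```python
# Hypothetical pure-Python stand-in for the int/int case of ceiling division;
# it is NOT jax.experimental.pallas.cdiv itself.
def ceil_div(a: int, b: int) -> int:
    """Return ceil(a / b) for integers a and b, assuming b > 0."""
    if b <= 0:
        raise ValueError("this sketch assumes a positive divisor")
    # Adding (b - 1) before floor division rounds the quotient up
    # whenever a is not an exact multiple of b.
    return (a + b - 1) // b


def blocks_needed(length: int, block_size: int) -> tuple[int, int]:
    """Hypothetical helper: number of blocks that cover `length` elements,
    plus the padded length those blocks span."""
    num_blocks = ceil_div(length, block_size)
    return num_blocks, num_blocks * block_size


if __name__ == "__main__":
    print(ceil_div(8, 2))            # 4
    print(ceil_div(9, 2))            # 5
    print(blocks_needed(1000, 128))  # (8, 1024): the last block is partly padding
```

Here 1000 elements split into blocks of 128 need 8 blocks, because 7 blocks would cover only 896 elements; the final block extends 24 elements past the end of the data.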