jax.experimental.pallas.cdiv

jax.experimental.pallas.cdiv(a: int, b: int) -> int
jax.experimental.pallas.cdiv(a: int, b: jax_typing.Array) -> jax_typing.Array
jax.experimental.pallas.cdiv(a: jax_typing.Array, b: int) -> jax_typing.Array
jax.experimental.pallas.cdiv(a: jax_typing.Array, b: jax_typing.Array) -> jax_typing.Array

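Computes the ceiling division of a by b, that is, a / b rounded up to the nearest integer. If both a and b are Python ints, the result is an int; if either is a jax_typing.Array, the result is a jax_typing.Array.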
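For intuition, here is a minimal plain-Python sketch of integer ceiling division. `ceil_div` is a hypothetical helper written for illustration, not the Pallas implementation. It agrees with the documented examples for positive integers (cdiv(8, 2) == 4, cdiv(9, 2) == 5); no claim is made here about how cdiv treats negative operands. The trick relies on Python's `//` rounding toward negative infinity, so negating before and after the floor division yields a ceiling.

```python
def ceil_div(a: int, b: int) -> int:
    """Return ceil(a / b) using only integer arithmetic.

    Hypothetical helper for illustration; not the library's code.
    Since floor(-a / b) == -ceil(a / b), negating around Python's
    floor division (//) turns a floor into a ceiling.
    """
    return -((-a) // b)


# Integer-only arithmetic avoids the float rounding that
# math.ceil(a / b) can suffer for very large integers.
for a, b in [(8, 2), (9, 2), (1, 128), (128, 128), (129, 128)]:
    print(a, b, ceil_div(a, b))
```

Staying in integer arithmetic is the design point: a float-based `math.ceil(a / b)` gives the wrong answer once `a` exceeds the range where floats represent integers exactly (for example `a = 10**17 + 1`, `b = 1`).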

Examples

>>> from jax.experimental.pallas import cdiv
>>> cdiv(8, 2)
4
>>> cdiv(9, 2)  # 9 / 2 = 4.5, which rounds up to 5
5
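In Pallas code, ceiling division commonly appears when sizing a grid so that fixed-size blocks cover an array dimension that is not a multiple of the block size, for example `jax.experimental.pallas.cdiv(n, block_size)`. The sketch below shows that arithmetic in plain Python. `num_blocks`, `length`, and `block_size` are illustrative names assumed for this example, not part of the Pallas API. It uses the other common formulation, `(a + b - 1) // b`, which equals the ceiling of a / b whenever b > 0.

```python
def num_blocks(length: int, block_size: int) -> int:
    """Number of fixed-size blocks needed to cover `length` elements.

    Hypothetical helper: ceiling division written as
    (a + b - 1) // b, which is exact for integer a and b > 0.
    """
    return (length + block_size - 1) // block_size


length, block_size = 1000, 128
n = num_blocks(length, block_size)
# 7 full blocks cover 896 elements; an 8th, partial block covers the last 104.
print(n)                        # 8
# Slots in the last block that fall past the end of the array.
print(n * block_size - length)  # 24
```

Rounding up rather than down guarantees every element is covered; a kernel using such a grid typically has to mask or otherwise handle the out-of-range slots in the final block.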