jax.experimental.pallas.mosaic_gpu.as_torch_kernel

jax.experimental.pallas.mosaic_gpu.as_torch_kernel(fn)

Makes a Mosaic GPU kernel callable with PyTorch tensors.

Parameters:

fn – A JAX function that invokes a Mosaic GPU kernel. The implementation currently supports only functions that contain exactly one Mosaic GPU kernel invocation and no other JAX operations (for example, no calls to jax.numpy functions before or after the kernel).

Returns:

A wrapper function that accepts PyTorch tensors as inputs and returns PyTorch tensors as outputs. The output tensors are allocated on the same device as the input tensors.

Example:

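import functools
import jax
import jax.numpy as jnp
import torch
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

# A Pallas kernel that adds two int32 vectors of length 128 elementwise.
@functools.partial(
    pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32)
)
def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

x = torch.arange(128, dtype=torch.int32, device="cuda")
y = x * x
# `out` is a PyTorch tensor holding x + y, on the same device as x and y.
out = plgpu.as_torch_kernel(add_kernel)(x, y)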