jax.experimental.pallas.mosaic_gpu.as_torch_kernel
- jax.experimental.pallas.mosaic_gpu.as_torch_kernel(fn)
Makes a Mosaic GPU kernel callable with PyTorch tensors.
- Parameters:
fn – A JAX function that invokes a Mosaic GPU kernel. Note that the implementation currently only supports functions that contain a single Mosaic GPU kernel invocation, without any other JAX API calls, e.g. from jax.numpy.
- Returns:
A wrapper function that accepts PyTorch tensors as inputs and returns PyTorch tensors as outputs. The output tensors are allocated on the same device as the input tensors.
Example:
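import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
import torch

@functools.partial(
    pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32)
)
def add_kernel(x_ref, y_ref, o_ref):
    # Element-wise addition of the two input references.
    o_ref[...] = x_ref[...] + y_ref[...]

x = torch.arange(128, dtype=torch.int32, device="cuda")
y = x * x
out = plgpu.as_torch_kernel(add_kernel)(x, y)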
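Because the wrapped function may contain only the kernel invocation itself, any pre- or post-processing has to happen on the PyTorch side. The following is a minimal sketch of that pattern, not part of the library: the helper name check_add_kernel is hypothetical, and it assumes a CUDA device that Mosaic GPU supports and that the wrapper returns a single tensor, as in the example above. It returns None when the required packages or a GPU are unavailable.

```python
def check_add_kernel():
    """Hypothetical helper: run the wrapped kernel and compare with PyTorch."""
    try:
        import functools

        import jax
        import jax.numpy as jnp
        import torch
        from jax.experimental import pallas as pl
        from jax.experimental.pallas import mosaic_gpu as plgpu
    except ImportError:
        return None  # jax, torch, or the Mosaic GPU backend is not installed
    if not torch.cuda.is_available():
        return None  # the kernel needs a CUDA device

    @functools.partial(
        pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32)
    )
    def add_kernel(x_ref, y_ref, o_ref):
        # The wrapped function contains only the kernel, no jax.numpy calls.
        o_ref[...] = x_ref[...] + y_ref[...]

    x = torch.arange(128, dtype=torch.int32, device="cuda")
    y = x * x  # preprocessing is done in PyTorch, outside the wrapped function
    out = plgpu.as_torch_kernel(add_kernel)(x, y)

    # Assumption: `out` is a single tensor on the same device as the inputs.
    return bool(torch.equal(out, x + y))


if __name__ == "__main__":
    print(check_add_kernel())
```

Computing the reference result with ordinary PyTorch operations keeps the wrapped function free of other JAX API calls while still giving a cheap correctness check.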