jax.experimental.pallas.mosaic_gpu.query_cluster_cancel


jax.experimental.pallas.mosaic_gpu.query_cluster_cancel(result_ref, grid_names)

Decodes the result of a try_cluster_cancel operation.

It interprets the 16-byte opaque response written to shared memory (SMEM) by a completed try_cluster_cancel call to determine whether a new work unit was successfully claimed.

Parameters:
  • result_ref (_Ref) – The SMEM ref holding the response written by try_cluster_cancel.

  • grid_names (Sequence[Hashable]) – A tuple of grid axis names to query for.
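The grid_names parameter selects which axes of the claimed work unit are reported. The following pure-Python sketch models that selection, assuming the indices come back in the order the names are given. The helper select_grid_indices and the dict standing in for the decoded response are hypothetical and not part of JAX, which decodes the opaque response on the GPU.

```python
from collections.abc import Hashable, Sequence

# Hypothetical stand-in for a decoded response: the claimed work unit's
# index along every named grid axis. The real API decodes this from the
# opaque 16-byte SMEM response; here it is a plain dict for illustration.
ClaimedIndex = dict[Hashable, int]


def select_grid_indices(
    claimed: ClaimedIndex, grid_names: Sequence[Hashable]
) -> tuple[int, ...]:
  """Returns the claimed indices for grid_names, in the order requested."""
  missing = [name for name in grid_names if name not in claimed]
  if missing:
    raise KeyError(f"unknown grid axis names: {missing}")
  return tuple(claimed[name] for name in grid_names)


claimed = {"m": 3, "n": 7}
print(select_grid_indices(claimed, ("m", "n")))  # (3, 7)
print(select_grid_indices(claimed, ("n",)))      # (7,)
```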

Returns:

A tuple containing the decoded response:

  • The grid indices for the requested axis names.

  • A boolean indicating whether the cancellation was successful.
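A consumer typically calls try_cluster_cancel, then query_cluster_cancel, and keeps processing claimed work units until the boolean comes back false. The sketch below models that loop on the CPU under stated assumptions: FakeClusterScheduler, its try_cancel and query methods, and run_block are hypothetical stand-ins, not JAX APIs, and indices returned alongside a false flag are treated as meaningless (the model returns zeros). In a real Pallas kernel the loop would presumably be expressed with traced control flow over SMEM refs rather than Python objects.

```python
from __future__ import annotations

from collections import deque
from collections.abc import Hashable, Sequence


class FakeClusterScheduler:
  """Hypothetical CPU model of clusters that have not launched yet.

  Not part of JAX: it only mimics the (indices, success) contract of
  query_cluster_cancel so the consuming loop can run anywhere.
  """

  def __init__(self, pending: list[dict[Hashable, int]]):
    self._pending = deque(pending)

  def try_cancel(self) -> dict[Hashable, int] | None:
    # Stand-in for try_cluster_cancel: claim one pending cluster, if any.
    return self._pending.popleft() if self._pending else None

  def query(self, response, grid_names: Sequence[Hashable]):
    # Stand-in for query_cluster_cancel: returns (indices, success).
    # Assumption: indices are meaningless on failure; this model uses zeros.
    if response is None:
      return tuple(0 for _ in grid_names), False
    return tuple(response[name] for name in grid_names), True


def run_block(own_index, scheduler, grid_names):
  """Records this block's own work unit, then every unit it claims."""
  processed = [own_index]
  while True:
    response = scheduler.try_cancel()
    indices, claimed = scheduler.query(response, grid_names)
    if not claimed:
      break  # No pending cluster was canceled; the block can exit.
    processed.append(indices)
  return processed


sched = FakeClusterScheduler([{"m": 1, "n": 0}, {"m": 2, "n": 0}])
print(run_block((0, 0), sched, ("m", "n")))  # [(0, 0), (1, 0), (2, 0)]
```

The design point the model captures is that the boolean, not the indices, decides whether to continue: once no pending cluster can be canceled, the block stops instead of reading undefined indices.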