jax.experimental.pallas.mosaic_gpu.wait_smem_to_gmem

jax.experimental.pallas.mosaic_gpu.wait_smem_to_gmem(n, wait_read_only=False)

Waits until no more than the most recent n SMEM->GMEM copies issued by the calling thread are in flight.

Parameters:
  • n (int) – The maximum number of the most recently issued copies that may still be in flight when the call returns. Pass 0 to wait for all outstanding copies.

  • wait_read_only (bool) – If True, only wait until the awaited copies have finished reading from SMEM, so that their source buffers can be reused. Their writes to GMEM are not waited for.

Return type:

None
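To make the counting rule concrete, here is a toy, pure-Python model of what each call guarantees. It is not part of JAX and does not run on a GPU: the class ToyCopyTracker and its methods issue, wait, smem_reusable and gmem_visible are hypothetical names chosen for illustration. The model records only the minimum progress each wait guarantees; real copies may complete earlier. It assumes a single calling thread, with copies identified by name in issue order.

```python
class ToyCopyTracker:
    """Hypothetical model of wait_smem_to_gmem's guarantees (not a JAX API)."""

    def __init__(self):
        self.issued = []      # copy names in the order they were issued
        self.reads_done = 0   # the first `reads_done` copies finished reading SMEM
        self.writes_done = 0  # the first `writes_done` copies finished writing GMEM

    def issue(self, name):
        # Stands in for starting an SMEM->GMEM copy.
        self.issued.append(name)

    def wait(self, n, wait_read_only=False):
        # Every copy except the n most recent ones must reach the awaited phase.
        target = max(len(self.issued) - n, 0)
        # Both kinds of wait guarantee that those copies have read their SMEM source.
        self.reads_done = max(self.reads_done, target)
        if not wait_read_only:
            # A full wait additionally guarantees that their GMEM writes have landed.
            self.writes_done = max(self.writes_done, target)

    def smem_reusable(self):
        # Copies whose SMEM source buffer is guaranteed safe to overwrite.
        return self.issued[: self.reads_done]

    def gmem_visible(self):
        # Copies whose GMEM destination is guaranteed to be written.
        return self.issued[: self.writes_done]


tracker = ToyCopyTracker()
for name in ["a", "b", "c"]:
    tracker.issue(name)

tracker.wait(1, wait_read_only=True)
print(tracker.smem_reusable())  # ['a', 'b']
print(tracker.gmem_visible())   # []

tracker.wait(0)
print(tracker.gmem_visible())   # ['a', 'b', 'c']
```

In the model, wait(1, wait_read_only=True) leaves only the newest copy outstanding and frees the SMEM of the two older ones without saying anything about GMEM. The following wait(0) then requires every copy to have fully completed.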