jax.experimental.pallas.tpu.sample_block

jax.experimental.pallas.tpu.sample_block(sampler_fn, global_key, block_size, tile_size, total_size, block_index=None, **kwargs)

Samples a block of random values with invariance guarantees.

sample_block produces the same random values for the same region of the total array across kernels that use different block shapes or iteration orders. Each call to sample_block returns a block_size-shaped array of random samples for the block at block_index.
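
The invariance can be illustrated with a small NumPy model. This is a sketch under stated assumptions, not the Pallas implementation: it assumes, as the tile_size parameter suggests, that samples are generated one tile at a time and that each tile's randomness depends only on the global key and the tile's global position. The names sample_block_model and assemble are hypothetical, a NumPy generator seeded with (seed, linear tile index) stands in for JAX key derivation, and a uniform draw stands in for sampler_fn.

```python
import numpy as np


def sample_block_model(global_seed, block_size, tile_size, total_size, block_index):
    # Hypothetical NumPy stand-in for sample_block; NOT the Pallas implementation.
    # Every block must be tiled exactly in each dimension.
    if any(b % t for b, t in zip(block_size, tile_size)):
        raise ValueError("tile_size must divide block_size in every dimension")
    tiles_per_block = tuple(b // t for b, t in zip(block_size, tile_size))
    tiles_total = tuple(s // t for s, t in zip(total_size, tile_size))
    out = np.empty(block_size)
    for local in np.ndindex(*tiles_per_block):
        # Global tile coordinates: block offset (in tiles) plus offset inside the block.
        tile_idx = tuple(
            bi * n + li for bi, n, li in zip(block_index, tiles_per_block, local)
        )
        # Assumed key derivation: the tile's generator depends only on the
        # global seed and the tile's global position, never on block_size.
        linear = int(np.ravel_multi_index(tile_idx, tiles_total))
        rng = np.random.default_rng([global_seed, linear])
        region = tuple(slice(li * t, (li + 1) * t) for li, t in zip(local, tile_size))
        out[region] = rng.uniform(size=tile_size)
    return out


def assemble(global_seed, block_size, tile_size, total_size):
    # Emulate a kernel grid: visit every block and write it into the full array.
    grid = tuple(s // b for s, b in zip(total_size, block_size))
    full = np.empty(total_size)
    for block_index in np.ndindex(*grid):
        region = tuple(slice(i * b, (i + 1) * b) for i, b in zip(block_index, block_size))
        full[region] = sample_block_model(
            global_seed, block_size, tile_size, total_size, block_index
        )
    return full


# Two "kernels" with different block shapes but a shared (4, 4) tile.
a = assemble(0, (8, 4), (4, 4), (16, 16))
b = assemble(0, (4, 16), (4, 4), (16, 16))
print(np.array_equal(a, b))  # True
```

Because each tile's values ignore the block shape, changing block_size only regroups the same tiles, which is why tile_size has to divide every block size involved.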

tile_size should be chosen so that it divides every block size that the samples need to be invariant to. Larger tiles make sampling more efficient, so the best choice is typically the greatest common divisor of all possible block sizes, taken per dimension.
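
The greatest-common-divisor rule can be computed per axis with the standard library. In this minimal sketch, the helper common_tile_size and the candidate block shapes are hypothetical; the tiles-per-block count is included on the assumption that fewer, larger tiles mean fewer sampler_fn calls per block.

```python
import math
from functools import reduce


def common_tile_size(block_sizes):
    # Per-axis greatest common divisor of all candidate block shapes.
    return tuple(reduce(math.gcd, dims) for dims in zip(*block_sizes))


# Hypothetical block shapes that different kernels might use.
block_sizes = [(256, 512), (128, 1024), (512, 256)]
tile = common_tile_size(block_sizes)
print(tile)  # (128, 256)

# Tiles (and therefore sampler calls, under the assumption above) per block.
tiles_per_block = [math.prod(b // t for b, t in zip(bs, tile)) for bs in block_sizes]
print(tiles_per_block)  # [4, 4, 4]
```

Any common divisor smaller than (128, 256) would also keep the samples invariant, but would split each block into more, smaller tiles.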

Parameters:
  • sampler_fn (SampleFn) – A sampling function that consumes a key and returns random samples.

  • global_key (Array) – The global key to use for sampling.

  • block_size (tuple[int, ...]) – The shape of an individual block.

  • tile_size (tuple[int, ...]) – The shape of a tile, the smallest unit at which samples are generated. It should divide every block size that the samples need to be invariant to.

  • total_size (tuple[int, ...]) – The total size of the array to sample.

  • block_index (tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, ...] | None) – The index of the block to sample. Defaults to the current program_id along each block axis.

  • **kwargs – Additional keyword arguments forwarded to sampler_fn.

Returns:

A block_size-shaped array of samples for the block at block_index.

Return type:

Array