jax.experimental.pallas.tpu.sample_block
- jax.experimental.pallas.tpu.sample_block(sampler_fn, global_key, block_size, tile_size, total_size, block_index=None, **kwargs)
Samples a block of random values with invariance guarantees.
sample_block allows the sampling of identical blocks of random values across kernels with different block shapes and iteration orders. Each call to sample_block returns a block_size-shaped array of random samples corresponding to the block_index.
tile_size should be chosen such that it is a divisor of all block sizes one needs to be invariant to. The larger the tile_size, the more efficient the sampling process will be, so the best choice is typically the greatest common divisor of all possible block sizes.
- Parameters:
sampler_fn (SampleFn) – A sampling function that consumes a key and returns random samples.
global_key (Array) – The global key to use for sampling.
block_size (tuple[int, ...]) – The shape of an individual block.
tile_size (tuple[int, ...]) – The shape of a tile, which is the smallest unit at which samples are generated. This should be selected to be a divisor of all block sizes one needs to be invariant to.
total_size (tuple[int, ...]) – The total size of the array to sample.
block_index (tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, ...] | None) – The index denoting which block to generate keys for. Defaults to the program_id for each block axis.
**kwargs – Additional arguments to pass to the sampler_fn.
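The invariance idea described above can be illustrated with a minimal pure-JAX sketch. It does not call sample_block and is not the library's implementation; it only assumes that each tile's key is derived from the global key and the tile's global position, so the same tile yields the same values regardless of which block contains it. The helpers gcd_tile_size, sample_block_sketch, and sample_all_blocks are hypothetical names introduced for illustration, and the sketch handles only 2D shapes.

```python
import functools
import math

import jax
import jax.numpy as jnp


def gcd_tile_size(block_sizes):
    """Per-axis greatest common divisor of several block shapes (hypothetical helper)."""
    return tuple(functools.reduce(math.gcd, axis) for axis in zip(*block_sizes))


def sample_block_sketch(key, block_size, tile_size, total_size, block_index):
    """Builds one 2D block from tiles keyed by their global tile position.

    Assumption: folding the global tile index into the key stands in for
    whatever key derivation the real sample_block uses.
    """
    tiles_per_block = tuple(b // t for b, t in zip(block_size, tile_size))
    tiles_per_row = total_size[1] // tile_size[1]
    rows = []
    for i in range(tiles_per_block[0]):
        row = []
        for j in range(tiles_per_block[1]):
            # Global tile coordinates are independent of the block shape.
            gi = block_index[0] * tiles_per_block[0] + i
            gj = block_index[1] * tiles_per_block[1] + j
            tile_key = jax.random.fold_in(key, gi * tiles_per_row + gj)
            row.append(jax.random.uniform(tile_key, tile_size))
        rows.append(jnp.concatenate(row, axis=1))
    return jnp.concatenate(rows, axis=0)


def sample_all_blocks(key, block_size, tile_size, total_size):
    """Samples every block in the grid and stitches them into the full array."""
    grid = tuple(s // b for s, b in zip(total_size, block_size))
    rows = []
    for bi in range(grid[0]):
        row = [
            sample_block_sketch(key, block_size, tile_size, total_size, (bi, bj))
            for bj in range(grid[1])
        ]
        rows.append(jnp.concatenate(row, axis=1))
    return jnp.concatenate(rows, axis=0)


key = jax.random.PRNGKey(0)
total_size = (8, 8)
block_a, block_b = (4, 4), (2, 8)

# The per-axis GCD is the largest tile that divides both block shapes.
tile_size = gcd_tile_size([block_a, block_b])

full_a = sample_all_blocks(key, block_a, tile_size, total_size)
full_b = sample_all_blocks(key, block_b, tile_size, total_size)

# Both block shapes reproduce the same full array of samples.
same = bool(jnp.array_equal(full_a, full_b))
```

In an actual Pallas TPU kernel, block_index would typically be left as None so that it defaults to the program_id for each block axis, as noted in the parameter description.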
- Returns:
A block_size-shaped array of samples for the current block corresponding to block_index.
- Return type: