jax.experimental.pallas.tpu.to_pallas_key

jax.experimental.pallas.tpu.to_pallas_key(key)

Helper function for converting non-Pallas PRNG keys (for example, keys created with jax.random.key) into Pallas keys.

Parameters:

key (Array) – the non-Pallas PRNG key to convert.

Return type:

Array
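The idea of converting a key can be pictured as drawing fresh uint32 key material from the input key with the standard jax.random API. The block below is a minimal sketch under stated assumptions, not the implementation of to_pallas_key. The helper name derive_key_material and the default key_shape of (1, 1) are hypothetical. The sketch also assumes typed keys created with jax.random.key, and it treats every dimension of the key array as a batch dimension handled by one jax.vmap per dimension.

```python
import functools

import jax
import jax.numpy as jnp


def derive_key_material(key, key_shape=(1, 1)):
    """Hypothetical helper: derive uint32 key data from a standard JAX key.

    Assumption: illustrates the idea of converting a non-Pallas key only;
    it is NOT the implementation of pltpu.to_pallas_key. Expects typed keys
    from jax.random.key (legacy uint32 keys would be mis-batched).
    """
    # Draw raw uint32 bits with the (hypothetical) target key shape.
    draw = functools.partial(jax.random.bits, shape=key_shape, dtype=jnp.uint32)
    # Every dimension of a typed key array is a batch dimension: vmap over each.
    for _ in range(len(key.shape)):
        draw = jax.vmap(draw)
    return draw(key)


key = jax.random.key(0)
single = derive_key_material(key)
batch = derive_key_material(jax.random.split(key, 4))
print(single.shape, single.dtype)  # (1, 1) uint32
print(batch.shape)                 # (4, 1, 1)
```

In this sketch, drawing fresh bits instead of copying the input key's raw data means the output shape depends only on key_shape, not on which PRNG implementation produced the source key.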