jax.extend.pallas.register_lowering_rule

jax.extend.pallas.register_lowering_rule(params_cls, rule, platform)
Parameters:

params_cls
rule
platform (str)