jax.dtypes module
bfloat16 floating-point values
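A short sketch, assuming this entry refers to the jax.dtypes.bfloat16 scalar type. bfloat16 keeps float32's 8-bit exponent, which gives it a comparable range. It has only 7 explicit mantissa bits, so small increments near 1.0 are rounded away.

```python
import jax
import jax.numpy as jnp

bf16 = jax.dtypes.bfloat16  # scalar type for bfloat16 values

# Storage: 16 bits per element (1 sign, 8 exponent, 7 mantissa bits).
x = jnp.asarray([1.0, 2.5, -3.0], dtype=bf16)
print(x.dtype, x.dtype.itemsize)  # bfloat16 2

info = jnp.finfo(bf16)
print(info.bits, info.nmant, info.nexp)  # 16 7 8

# Same exponent width as float32, so a comparable maximum value...
print(float(info.max) > 3e38)  # True
# ...but much coarser spacing: 1 + 2**-9 rounds back to 1.0 in bfloat16.
print(float(jnp.asarray(1 + 2**-9, dtype=bf16)))         # 1.0
print(float(jnp.asarray(1 + 2**-9, dtype=jnp.float32)))  # 1.001953125
```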
Convert from a dtype to a canonical dtype based on config.x64_enabled.
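A minimal sketch, assuming this entry is jax.dtypes.canonicalize_dtype. With the default configuration (64-bit types disabled), float64 canonicalizes to float32 and int64 to int32. With x64 enabled, 64-bit dtypes are returned unchanged, so the printed mapping depends on that setting.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Print how each requested dtype is canonicalized under the current x64 setting.
for requested in (np.float64, np.int64, np.complex128, np.float32, np.int8):
    canonical = jax.dtypes.canonicalize_dtype(requested)
    print(np.dtype(requested).name, "->", canonical.name)

# Arrays built from 64-bit NumPy data end up with the canonical dtype,
# whichever way x64 is configured.
arr = jnp.asarray(np.arange(3, dtype=np.int64))
print(arr.dtype == jax.dtypes.canonicalize_dtype(np.int64))  # True
```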
VoidDType(length, /)
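The entry above shows only a NumPy void-dtype constructor signature. Assuming it describes jax.dtypes.float0, the zero-width dtype that JAX uses for tangents and cotangents of integer and boolean inputs, this sketch shows where it shows up in practice. The function scale is a hypothetical example.

```python
import jax

def scale(x, n):
    # x is a float input; n is an integer input.
    return x * n

# allow_int=True permits differentiating with respect to the integer argument;
# its "gradient" comes back with the trivial float0 dtype.
gx, gn = jax.grad(scale, argnums=(0, 1), allow_int=True)(2.0, 3)

print(gx)                             # 3.0
print(gn.dtype == jax.dtypes.float0)  # True
print(jax.dtypes.float0.itemsize)     # 0: float0 values carry no data
```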
Number of bits per element for the dtype.
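To illustrate the quantity this entry reports, here is a sketch for byte-aligned dtypes, where the width in bits is simply itemsize * 8. The helper bits_per_element is hypothetical and is not the jax.dtypes function listed here. Sub-byte types such as int4 are deliberately out of scope, since their width is not a whole number of bytes.

```python
import numpy as np
import jax

def bits_per_element(dtype) -> int:
    """Hypothetical helper: element width in bits, valid for byte-aligned dtypes only."""
    return np.dtype(dtype).itemsize * 8

for dt in (np.bool_, np.int8, jax.dtypes.bfloat16, np.float32, np.complex64):
    print(np.dtype(dt).name, bits_per_element(dt))
# bool 8, int8 8, bfloat16 16, float32 32, complex64 64
```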
Returns True if the first argument is a typecode lower than or equal to the second in the type hierarchy.
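A short sketch, assuming this entry is jax.dtypes.issubdtype. It follows NumPy's type hierarchy and also places JAX's additional floating-point types, such as bfloat16, under the floating category.

```python
import numpy as np
import jax
import jax.numpy as jnp

issub = jax.dtypes.issubdtype

print(issub(np.int32, np.integer))               # True: int32 is an integer type
print(issub(np.float32, np.integer))             # False
print(issub(np.uint8, np.signedinteger))         # False: uint8 is unsigned
print(issub(np.float32, np.float32))             # True: "lower or equal" includes equality
print(issub(jax.dtypes.bfloat16, jnp.floating))  # True: bfloat16 counts as floating
```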
Scalar class for PRNG Key dtypes.
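A minimal sketch, assuming this entry is jax.dtypes.prng_key. Typed key arrays created with jax.random.key have a key dtype that is a subtype of prng_key. Legacy keys from jax.random.PRNGKey are plain uint32 arrays under the default configuration, so they are not.

```python
import jax

new_key = jax.random.key(0)         # typed key array with a PRNG key dtype
legacy_key = jax.random.PRNGKey(0)  # legacy raw key: a plain uint32 array

print(new_key.dtype)  # key<fry>
print(jax.dtypes.issubdtype(new_key.dtype, jax.dtypes.prng_key))     # True
print(jax.dtypes.issubdtype(legacy_key.dtype, jax.dtypes.prng_key))  # False
```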
Convenience function to apply JAX argument dtype promotion.
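Assuming this entry is jax.dtypes.result_type, the sketch below shows a few of JAX's promotion rules. These differ from NumPy's in places: JAX avoids silently widening to 64-bit types, and Python scalars are weakly typed, so they defer to the array's dtype.

```python
import numpy as np
import jax

rt = jax.dtypes.result_type

print(rt(np.int32, np.float32))             # float32 (NumPy would choose float64)
print(rt(np.uint8, np.int8))                # int16
print(rt(jax.dtypes.bfloat16, np.float16))  # float32
print(rt(np.int8, 2))                       # int8: the Python int 2 is weakly typed
```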
Return the scalar type associated with a JAX value.
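A short sketch, assuming this entry is jax.dtypes.scalar_type_of. It maps a value's dtype to the corresponding Python scalar type: bool, int, float, or complex. JAX's extra float types such as bfloat16 map to float.

```python
import jax
import jax.numpy as jnp

sto = jax.dtypes.scalar_type_of

print(sto(jnp.arange(3)))                    # <class 'int'>
print(sto(jnp.ones(2, dtype=jnp.bfloat16)))  # <class 'float'>
print(sto(jnp.array(True)))                  # <class 'bool'>
print(sto(jnp.array(1 + 2j)))                # <class 'complex'>
```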
Raised when JAX type promotion fails.
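Assuming this entry describes the TypePromotionError class, one way to trigger it is the jax.numpy_dtype_promotion("strict") context manager. That mode disallows implicit promotion between mixed dtypes such as int32 and float32. The sketch catches the exception broadly and inspects its class name.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.int32)
y = jnp.ones(3, dtype=jnp.float32)

# Under the default "standard" promotion mode, int32 + float32 -> float32.
print((x + y).dtype)  # float32

# Under "strict" mode, the same implicit promotion is rejected.
error_name = None
try:
    with jax.numpy_dtype_promotion("strict"):
        x + y
except Exception as err:
    error_name = type(err).__name__

print(error_name)  # TypePromotionError
```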