jax.Array.byteswap

abstract Array.byteswap()

Swap the bytes of the array elements.

This converts between little-endian and big-endian data representations.

Returns:

An array with the same dtype as self, with underlying bytes of each entry reversed.

Parameters:

self (Array)

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> x = jnp.arange(5, dtype='int32')
>>> x
Array([0, 1, 2, 3, 4], dtype=int32)
>>> x.byteswap()
Array([       0, 16777216, 33554432, 50331648, 67108864], dtype=int32)

When the resulting bytes are viewed as a big-endian dtype (possible in NumPy, but not in JAX), they represent the original values:

>>> import numpy as np
>>> np.array(x.byteswap()).view('>i4')  # view as big-endian
array([0, 1, 2, 3, 4], dtype='>i4')

Calling byteswap twice returns the original array:

>>> x.byteswap().byteswap()
Array([0, 1, 2, 3, 4], dtype=int32)