Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generalize ravel_pytree to handle int types #6136

Merged
merged 1 commit into from
Mar 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions jax/flatten_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import numpy as np

from .tree_util import tree_flatten, tree_unflatten
from ._src.util import safe_zip
from ._src.util import safe_zip, unzip2

import jax.numpy as jnp
from jax.api import vjp
from jax import dtypes
from jax import lax

zip = safe_zip

Expand All @@ -26,18 +30,40 @@ def ravel_pytree(pytree):
"""Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

Args:
pytree: a pytree to ravel.
pytree: a pytree of arrays and scalars to ravel.

Returns:
A pair where the first element is a 1D array representing the flattened and
concatenated leaf values, and the second element is a callable for
unflattening a 1D vector of the same length back to a pytree of of the same
structure as the input ``pytree``.
concatenated leaf values, with dtype determined by promoting the dtypes of
leaf values, and the second element is a callable for unflattening a 1D
vector of the same length back to a pytree of of the same structure as the
input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as
a convention a 1D empty array of dtype float32 is returned in the first
component of the output.

For details on dtype promotion, see
https://jax.readthedocs.io/en/latest/type_promotion.html.

"""
leaves, treedef = tree_flatten(pytree)
flat, unravel_list = vjp(_ravel_list, *leaves)
flat, unravel_list = _ravel_list(leaves)
unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
return flat, unravel_pytree

def _ravel_list(*lst):
return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([])
def _ravel_list(lst):
if not lst: return jnp.array([], jnp.float32), lambda _: []
from_dtypes = [dtypes.dtype(l) for l in lst]
jakevdp marked this conversation as resolved.
Show resolved Hide resolved
to_dtype = dtypes.result_type(*from_dtypes)
sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
indices = np.cumsum(sizes)

def unravel(arr):
chunks = jnp.split(arr, indices[:-1])
with warnings.catch_warnings():
warnings.simplefilter("ignore") # ignore complex-to-real cast warning
return [lax.convert_element_type(chunk.reshape(shape), dtype)
for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]

ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
raveled = jnp.concatenate([ravel(e) for e in lst])
return raveled, unravel
54 changes: 54 additions & 0 deletions tests/tree_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

from jax import test_util as jtu
from jax import tree_util
from jax import flatten_util
from jax import dtypes
import jax.numpy as jnp


def _dummy_func(*args, **kwargs):
Expand Down Expand Up @@ -274,5 +277,56 @@ def testTransposeWithCustomObject(self):
FlatCache({"a": [3, 4], "b": [5, 6]}))
self.assertEqual(expected, actual)


class RavelUtilTest(jtu.JaxTestCase):

def testFloats(self):
tree = [jnp.array([3.], jnp.float32),
jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, jnp.float32)
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)

def testInts(self):
tree = [jnp.array([3], jnp.int32),
jnp.array([[1, 2], [3, 4]], jnp.int32)]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, jnp.int32)
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)

def testMixedFloatInt(self):
tree = [jnp.array([3], jnp.int32),
jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.float32, jnp.int32))
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)

def testMixedIntBool(self):
tree = [jnp.array([0], jnp.bool_),
jnp.array([[1, 2], [3, 4]], jnp.int32)]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.bool_, jnp.int32))
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)

def testMixedFloatComplex(self):
tree = [jnp.array([1.], jnp.float32),
jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)]
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.float32, jnp.complex64))
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)

def testEmpty(self):
tree = []
raveled, unravel = flatten_util.ravel_pytree(tree)
self.assertEqual(raveled.dtype, jnp.float32) # convention
tree_ = unravel(raveled)
self.assertAllClose(tree, tree_, atol=0., rtol=0.)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())