diff --git a/jax/flatten_util.py b/jax/flatten_util.py index 7324de1a0fa2..249ff0441587 100644 --- a/jax/flatten_util.py +++ b/jax/flatten_util.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + +import numpy as np from .tree_util import tree_flatten, tree_unflatten -from ._src.util import safe_zip +from ._src.util import safe_zip, unzip2 import jax.numpy as jnp -from jax.api import vjp +from jax import dtypes +from jax import lax zip = safe_zip @@ -26,18 +30,40 @@ def ravel_pytree(pytree): """Ravel (i.e. flatten) a pytree of arrays down to a 1D array. Args: - pytree: a pytree to ravel. + pytree: a pytree of arrays and scalars to ravel. Returns: A pair where the first element is a 1D array representing the flattened and - concatenated leaf values, and the second element is a callable for - unflattening a 1D vector of the same length back to a pytree of of the same - structure as the input ``pytree``. + concatenated leaf values, with dtype determined by promoting the dtypes of + leaf values, and the second element is a callable for unflattening a 1D + vector of the same length back to a pytree of the same structure as the + input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as + a convention a 1D empty array of dtype float32 is returned in the first + component of the output. + + For details on dtype promotion, see + https://jax.readthedocs.io/en/latest/type_promotion.html. 
+ """ leaves, treedef = tree_flatten(pytree) - flat, unravel_list = vjp(_ravel_list, *leaves) + flat, unravel_list = _ravel_list(leaves) unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat)) return flat, unravel_pytree -def _ravel_list(*lst): - return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([]) +def _ravel_list(lst): + if not lst: return jnp.array([], jnp.float32), lambda _: [] + from_dtypes = [dtypes.dtype(l) for l in lst] + to_dtype = dtypes.result_type(*from_dtypes) + sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst) + indices = np.cumsum(sizes) + + def unravel(arr): + chunks = jnp.split(arr, indices[:-1]) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # ignore complex-to-real cast warning + return [lax.convert_element_type(chunk.reshape(shape), dtype) + for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)] + + ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype)) + raveled = jnp.concatenate([ravel(e) for e in lst]) + return raveled, unravel diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 6e8ed3d501d5..52f026da0169 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -20,6 +20,9 @@ from jax import test_util as jtu from jax import tree_util +from jax import flatten_util +from jax import dtypes +import jax.numpy as jnp def _dummy_func(*args, **kwargs): @@ -274,5 +277,56 @@ def testTransposeWithCustomObject(self): FlatCache({"a": [3, 4], "b": [5, 6]})) self.assertEqual(expected, actual) + +class RavelUtilTest(jtu.JaxTestCase): + + def testFloats(self): + tree = [jnp.array([3.], jnp.float32), + jnp.array([[1., 2.], [3., 4.]], jnp.float32)] + raveled, unravel = flatten_util.ravel_pytree(tree) + self.assertEqual(raveled.dtype, jnp.float32) + tree_ = unravel(raveled) + self.assertAllClose(tree, tree_, atol=0., rtol=0.) 
+ + def testInts(self): + tree = [jnp.array([3], jnp.int32), + jnp.array([[1, 2], [3, 4]], jnp.int32)] + raveled, unravel = flatten_util.ravel_pytree(tree) + self.assertEqual(raveled.dtype, jnp.int32) + tree_ = unravel(raveled) + self.assertAllClose(tree, tree_, atol=0., rtol=0.) + + def testMixedFloatInt(self): + tree = [jnp.array([3], jnp.int32), + jnp.array([[1., 2.], [3., 4.]], jnp.float32)] + raveled, unravel = flatten_util.ravel_pytree(tree) + self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.float32, jnp.int32)) + tree_ = unravel(raveled) + self.assertAllClose(tree, tree_, atol=0., rtol=0.) + + def testMixedIntBool(self): + tree = [jnp.array([0], jnp.bool_), + jnp.array([[1, 2], [3, 4]], jnp.int32)] + raveled, unravel = flatten_util.ravel_pytree(tree) + self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.bool_, jnp.int32)) + tree_ = unravel(raveled) + self.assertAllClose(tree, tree_, atol=0., rtol=0.) + + def testMixedFloatComplex(self): + tree = [jnp.array([1.], jnp.float32), + jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)] + raveled, unravel = flatten_util.ravel_pytree(tree) + self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.float32, jnp.complex64)) + tree_ = unravel(raveled) + self.assertAllClose(tree, tree_, atol=0., rtol=0.) + + def testEmpty(self): + tree = [] + raveled, unravel = flatten_util.ravel_pytree(tree) + self.assertEqual(raveled.dtype, jnp.float32) # convention + tree_ = unravel(raveled) + self.assertAllClose(tree, tree_, atol=0., rtol=0.) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())