diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 33b290b4..759555e8 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -23,6 +23,7 @@ Tree Manipulation Functions .. autosummary:: + dict_insertion_ordered tree_flatten tree_flatten_with_path tree_flatten_with_accessor @@ -55,6 +56,7 @@ Tree Manipulation Functions tree_flatten_one_level prefix_errors +.. autofunction:: dict_insertion_ordered .. autofunction:: tree_flatten .. autofunction:: tree_flatten_with_path .. autofunction:: tree_flatten_with_accessor diff --git a/include/treespec.h b/include/treespec.h index 939a702b..1f05b517 100644 --- a/include/treespec.h +++ b/include/treespec.h @@ -202,6 +202,23 @@ class PyTreeSpec { const bool &none_is_leaf = false, const std::string ®istry_namespace = ""); + // Check if should preserve the insertion order of the dictionary keys during flattening. + inline static bool IsDictInsertionOrdered(const std::string ®istry_namespace) { + return (sm_is_dict_insertion_ordered.find("") != sm_is_dict_insertion_ordered.end() || + sm_is_dict_insertion_ordered.find(registry_namespace) != + sm_is_dict_insertion_ordered.end()); + } + + // Set the namespace to preserve the insertion order of the dictionary keys during flattening. + inline static void SetDictInsertionOrdered(const bool &mode, + const std::string ®istry_namespace) { + if (mode) [[likely]] { + sm_is_dict_insertion_ordered.insert(registry_namespace); + } else [[unlikely]] { + sm_is_dict_insertion_ordered.erase(registry_namespace); + } + } + private: using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr; @@ -266,7 +283,7 @@ class PyTreeSpec { const bool &none_is_leaf, const std::string ®istry_namespace); - template + template bool FlattenIntoImpl(const py::handle &handle, Span &leaves, // NOLINT[runtime/references] const ssize_t &depth, @@ -281,7 +298,11 @@ class PyTreeSpec { const bool &none_is_leaf, const std::string ®istry_namespace); - template + template bool FlattenIntoWithPathImpl(const py::handle &handle, LeafSpan &leaves, // NOLINT[runtime/references] PathSpan &paths, // NOLINT[runtime/references] @@ -329,6 +350,10 @@ class PyTreeSpec { size_t operator()(const std::pair &p) const; }; + // A set of namespaces that preserve the insertion order of the dictionary keys during + // flattening. + inline static std::unordered_set sm_is_dict_insertion_ordered{}; + // A set of (treespec, thread_id) pairs that are currently being represented as strings. inline static std::unordered_set, ThreadIndentTypeHash> @@ -344,12 +369,13 @@ class PyTreeIter { public: PyTreeIter(const py::object &tree, const std::optional &leaf_predicate, - bool none_is_leaf, - std::string registry_namespace) + const bool &none_is_leaf, + const std::string ®istry_namespace) : m_agenda({std::make_pair(tree, 0)}), m_leaf_predicate(leaf_predicate), m_none_is_leaf(none_is_leaf), - m_namespace(std::move(registry_namespace)) {}; + m_namespace(registry_namespace), + m_is_dict_insertion_ordered(PyTreeSpec::IsDictInsertionOrdered(registry_namespace)) {}; PyTreeIter() = delete; @@ -377,6 +403,7 @@ class PyTreeIter { std::optional m_leaf_predicate; bool m_none_is_leaf; std::string m_namespace; + bool m_is_dict_insertion_ordered; template [[nodiscard]] py::object NextImpl(); diff --git a/optree/_C.pyi b/optree/_C.pyi index 8fd2f49c..83211795 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -158,3 +158,5 @@ def unregister_node( cls: type[CustomTreeNode[T]], namespace: str = '', ) -> None: ... +def is_dict_insertion_ordered(namespace: str = '') -> bool: ... +def set_dict_insertion_ordered(mode: bool, namespace: str = '') -> None: ... diff --git a/optree/__init__.py b/optree/__init__.py index 6a094535..d254b0b2 100644 --- a/optree/__init__.py +++ b/optree/__init__.py @@ -95,6 +95,7 @@ from optree.registry import ( AttributeKeyPathEntry, GetitemKeyPathEntry, + dict_insertion_ordered, register_keypaths, register_pytree_node, register_pytree_node_class, @@ -200,6 +201,7 @@ 'register_pytree_node', 'register_pytree_node_class', 'unregister_pytree_node', + 'dict_insertion_ordered', # Typing 'PyTreeSpec', 'PyTreeDef', diff --git a/optree/registry.py b/optree/registry.py index 4360e169..59245b76 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -16,6 +16,7 @@ from __future__ import annotations +import contextlib import dataclasses import functools import inspect @@ -29,6 +30,7 @@ Any, Callable, ClassVar, + Generator, Iterable, NamedTuple, Sequence, @@ -70,6 +72,7 @@ 'register_pytree_node', 'register_pytree_node_class', 'unregister_pytree_node', + 'dict_insertion_ordered', ] @@ -491,6 +494,17 @@ def _dict_unflatten(keys: list[KT], values: Iterable[VT]) -> dict[KT, VT]: return dict(safe_zip(keys, values)) +def _dict_insertion_ordered_flatten( + dct: dict[KT, VT], +) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]: + keys, values = unzip2(dct.items()) + return values, list(keys), keys + + +def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT]) -> dict[KT, VT]: + return dict(safe_zip(keys, values)) + + def _ordereddict_flatten( dct: OrderedDict[KT, VT], ) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]: @@ -517,6 +531,21 @@ def _defaultdict_unflatten( return defaultdict(default_factory, _dict_unflatten(keys, values)) +def _defaultdict_insertion_ordered_flatten( + dct: defaultdict[KT, VT], +) -> tuple[tuple[VT, ...], tuple[Callable[[], VT] | None, list[KT]], tuple[KT, ...]]: + values, keys, entries = _dict_insertion_ordered_flatten(dct) + return values, (dct.default_factory, keys), entries + + +def _defaultdict_insertion_ordered_unflatten( + metadata: tuple[Callable[[], VT], list[KT]], + values: Iterable[VT], +) -> defaultdict[KT, VT]: + default_factory, keys = metadata + return defaultdict(default_factory, _dict_insertion_ordered_unflatten(keys, values)) + + def _deque_flatten(deq: deque[T]) -> tuple[deque[T], int | None]: return deq, deq.maxlen @@ -566,6 +595,23 @@ def _pytree_node_registry_get( handler = _NODETYPE_REGISTRY.get((namespace, cls)) if handler is not None: return handler + + if _C.is_dict_insertion_ordered(namespace): + if cls is dict: + return PyTreeNodeRegistryEntry( + dict, + _dict_insertion_ordered_flatten, # type: ignore[arg-type] + _dict_insertion_ordered_unflatten, # type: ignore[arg-type] + path_entry_type=MappingEntry, + ) + if cls is defaultdict: + return PyTreeNodeRegistryEntry( + defaultdict, + _defaultdict_insertion_ordered_flatten, # type: ignore[arg-type] + _defaultdict_insertion_ordered_unflatten, # type: ignore[arg-type] + path_entry_type=MappingEntry, + ) + handler = _NODETYPE_REGISTRY.get(cls) if handler is not None: return handler @@ -580,6 +626,49 @@ def _pytree_node_registry_get( del _pytree_node_registry_get +@contextlib.contextmanager +def dict_insertion_ordered(mode: bool, *, namespace: str) -> Generator[None, None, None]: + """Context manager to temporarily set the dictionary sorting mode. + + This context manager is used to temporarily set the dictionary sorting mode for a specific + namespace. The dictionary sorting mode is used to determine whether the keys of a dictionary + should be sorted or keeping the insertion order when flattening a pytree. + + >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} + >>> tree_flatten(tree) # doctest: +IGNORE_WHITESPACE + ( + [1, 2, 3, 4, 5], + PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}) + ) + >>> with dict_insertion_ordered(True, namespace='some-namespace'): # doctest: +IGNORE_WHITESPACE + ... tree_flatten(tree, namespace='some-namespace') + ( + [2, 3, 4, 1, 5], + PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}) + ) + + Args: + mode (bool): The dictionary sorting mode to set. + namespace (str): The namespace to set the dictionary sorting mode for. + """ + if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str): + raise TypeError(f'The namespace must be a string, got {namespace!r}.') + if namespace == '': + raise ValueError('The namespace cannot be an empty string.') + if namespace is __GLOBAL_NAMESPACE: + namespace = '' + + with __REGISTRY_LOCK: + prev = _C.is_dict_insertion_ordered(namespace) + _C.set_dict_insertion_ordered(mode, namespace) + + try: + yield + finally: + with __REGISTRY_LOCK: + _C.set_dict_insertion_ordered(prev, namespace) + + #################################################################################################### warnings.filterwarnings('ignore', category=FutureWarning, module=__name__, append=True) diff --git a/src/optree.cpp b/src/optree.cpp index 3180d70f..fdbb0639 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -67,6 +67,15 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] "Unregister a Python type.", py::arg("cls"), py::arg("namespace") = "") + .def("is_dict_insertion_ordered", + &PyTreeSpec::IsDictInsertionOrdered, + "Return whether need to preserve the dict insertion order during flattening.", + py::arg("namespace") = "") + .def("set_dict_insertion_ordered", + &PyTreeSpec::SetDictInsertionOrdered, + "Set whether need to preserve the dict insertion order during flattening.", + py::arg("mode"), + py::arg("namespace") = "") .def("flatten", &PyTreeSpec::Flatten, "Flattens a pytree.", diff --git a/src/treespec/constructor.cpp b/src/treespec/constructor.cpp index 5a783e95..d4c594b3 100644 --- a/src/treespec/constructor.cpp +++ b/src/treespec/constructor.cpp @@ -167,7 +167,9 @@ template py::list keys = DictKeys(dict); if (node.kind != PyTreeKind::OrderedDict) [[likely]] { node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); - TotalOrderSort(keys); + if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { + TotalOrderSort(keys); + } } for (const py::handle& key : keys) { children.emplace_back(dict[key]); diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index b8d332b5..e4659266 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -31,7 +31,7 @@ limitations under the License. namespace optree { -template +template // NOLINTNEXTLINE[readability-function-cognitive-complexity] bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, Span& leaves, @@ -56,7 +56,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, // NOLINTNEXTLINE[misc-no-recursion] auto recurse = [this, &found_custom, &leaf_predicate, ®istry_namespace, &leaves, &depth]( const py::handle& child) -> void { - found_custom |= FlattenIntoImpl( + found_custom |= FlattenIntoImpl( child, leaves, depth + 1, leaf_predicate, registry_namespace); }; switch (node.kind) { @@ -98,7 +98,9 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, py::list keys = DictKeys(dict); if (node.kind != PyTreeKind::OrderedDict) [[likely]] { node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); - TotalOrderSort(keys); + if constexpr (DictShouldBeSorted) { + TotalOrderSort(keys); + } } for (const py::handle& key : keys) { recurse(dict[key]); @@ -184,9 +186,21 @@ bool PyTreeSpec::FlattenInto(const py::handle& handle, const bool& none_is_leaf, const std::string& registry_namespace) { if (none_is_leaf) [[unlikely]] { - return FlattenIntoImpl(handle, leaves, 0, leaf_predicate, registry_namespace); + if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { + return FlattenIntoImpl( + handle, leaves, 0, leaf_predicate, registry_namespace); + } else [[unlikely]] { + return FlattenIntoImpl( + handle, leaves, 0, leaf_predicate, registry_namespace); + } } else [[likely]] { - return FlattenIntoImpl(handle, leaves, 0, leaf_predicate, registry_namespace); + if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { + return FlattenIntoImpl( + handle, leaves, 0, leaf_predicate, registry_namespace); + } else [[unlikely]] { + return FlattenIntoImpl( + handle, leaves, 0, leaf_predicate, registry_namespace); + } } } @@ -198,15 +212,19 @@ bool PyTreeSpec::FlattenInto(const py::handle& handle, auto leaves = reserved_vector(4); auto treespec = std::make_unique(); treespec->m_none_is_leaf = none_is_leaf; - if (treespec->FlattenInto(tree, leaves, leaf_predicate, none_is_leaf, registry_namespace)) - [[unlikely]] { + if (treespec->FlattenInto(tree, leaves, leaf_predicate, none_is_leaf, registry_namespace) || + IsDictInsertionOrdered(registry_namespace)) [[unlikely]] { treespec->m_namespace = registry_namespace; } treespec->m_traversal.shrink_to_fit(); return std::make_pair(std::move(leaves), std::move(treespec)); } -template +template // NOLINTNEXTLINE[readability-function-cognitive-complexity] bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, LeafSpan& leaves, @@ -245,7 +263,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, &stack, &depth](const py::handle& child, const py::handle& entry) -> void { stack.emplace_back(entry); - found_custom |= FlattenIntoWithPathImpl( + found_custom |= FlattenIntoWithPathImpl( child, leaves, paths, stack, depth + 1, leaf_predicate, registry_namespace); stack.pop_back(); }; @@ -292,7 +310,9 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, py::list keys = DictKeys(dict); if (node.kind != PyTreeKind::OrderedDict) [[likely]] { node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); - TotalOrderSort(keys); + if constexpr (DictShouldBeSorted) { + TotalOrderSort(keys); + } } for (const py::handle& key : keys) { recurse(dict[key], key); @@ -396,11 +416,21 @@ bool PyTreeSpec::FlattenIntoWithPath(const py::handle& handle, const std::string& registry_namespace) { auto stack = reserved_vector(4); if (none_is_leaf) [[unlikely]] { - return FlattenIntoWithPathImpl( - handle, leaves, paths, stack, 0, leaf_predicate, registry_namespace); + if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { + return FlattenIntoWithPathImpl( + handle, leaves, paths, stack, 0, leaf_predicate, registry_namespace); + } else [[unlikely]] { + return FlattenIntoWithPathImpl( + handle, leaves, paths, stack, 0, leaf_predicate, registry_namespace); + } } else [[likely]] { - return FlattenIntoWithPathImpl( - handle, leaves, paths, stack, 0, leaf_predicate, registry_namespace); + if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { + return FlattenIntoWithPathImpl( + handle, leaves, paths, stack, 0, leaf_predicate, registry_namespace); + } else [[unlikely]] { + return FlattenIntoWithPathImpl( + handle, leaves, paths, stack, 0, leaf_predicate, registry_namespace); + } } } @@ -414,7 +444,8 @@ PyTreeSpec::FlattenWithPath(const py::object& tree, auto treespec = std::make_unique(); treespec->m_none_is_leaf = none_is_leaf; if (treespec->FlattenIntoWithPath( - tree, leaves, paths, leaf_predicate, none_is_leaf, registry_namespace)) [[unlikely]] { + tree, leaves, paths, leaf_predicate, none_is_leaf, registry_namespace) || + IsDictInsertionOrdered(registry_namespace)) [[unlikely]] { treespec->m_namespace = registry_namespace; } treespec->m_traversal.shrink_to_fit(); diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index 8762a56a..db6035af 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -83,7 +83,7 @@ py::object PyTreeIter::NextImpl() { case PyTreeKind::DefaultDict: { auto dict = py::reinterpret_borrow(object); py::list keys = DictKeys(dict); - if (kind != PyTreeKind::OrderedDict) [[likely]] { + if (kind != PyTreeKind::OrderedDict && !m_is_dict_insertion_ordered) [[likely]] { TotalOrderSort(keys); } if (PyList_Reverse(keys.ptr()) < 0) [[unlikely]] { diff --git a/tests/helpers.py b/tests/helpers.py index 2612564b..c6ad6f07 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -36,6 +36,9 @@ ) +GLOBAL_NAMESPACE = optree.registry.__GLOBAL_NAMESPACE # pylint: disable=protected-access + + def gc_collect(): for _ in range(3): gc.collect() @@ -154,12 +157,11 @@ def __call__(self, obj): lambda o: ((o.x, o.y), o.z), lambda z, xy: Vector3D(xy[0], xy[1], z), path_entry_type=Vector3DEntry, - namespace=optree.registry.__GLOBAL_NAMESPACE, # pylint: disable=protected-access + namespace=GLOBAL_NAMESPACE, ) -# pylint: disable-next=protected-access -@optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) +@optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) class Vector2D: def __init__(self, x, y): self.x = x @@ -186,8 +188,7 @@ def tree_unflatten(cls, metadata, children): # pylint: disable=unused-argument return cls(*children) -# pylint: disable-next=protected-access -@optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) +@optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) @dataclasses.dataclass class MyDataclass: alpha: Any @@ -207,10 +208,7 @@ def tree_unflatten(cls, metadata, children): return cls(*children) -@optree.register_pytree_node_class( - path_entry_type=optree.GetAttrEntry, - namespace=optree.registry.__GLOBAL_NAMESPACE, # pylint: disable=protected-access -) +@optree.register_pytree_node_class(path_entry_type=optree.GetAttrEntry, namespace=GLOBAL_NAMESPACE) @dataclasses.dataclass class MyOtherDataclass: a: Any @@ -232,8 +230,7 @@ def tree_unflatten(cls, metadata, children): return cls(a, b, c, d) -# pylint: disable-next=protected-access -@optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) +@optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) @dataclasses.dataclass class MyAnotherDataclass: x: Any @@ -248,8 +245,7 @@ def tree_unflatten(cls, metadata, children): return cls(*children) -# pylint: disable-next=protected-access -@optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) +@optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) class FlatCache: TREE_PATH_ENTRY_TYPE = optree.GetItemEntry @@ -288,9 +284,7 @@ def tree_unflatten(cls, metadata, children): return cls(structured=None, leaves=children, treespec=metadata) -@optree.register_pytree_node_class( - namespace=optree.registry.__GLOBAL_NAMESPACE, # pylint: disable=protected-access -) +@optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) class MyDict(UserDict): TREE_PATH_ENTRY_TYPE = optree.MappingEntry diff --git a/tests/test_functools.py b/tests/test_functools.py index 6e79e956..a77e1463 100644 --- a/tests/test_functools.py +++ b/tests/test_functools.py @@ -18,7 +18,7 @@ import functools import optree -from helpers import parametrize +from helpers import GLOBAL_NAMESPACE, parametrize def dummy_func(*args, **kwargs): # pylint: disable=unused-argument @@ -38,13 +38,27 @@ def dummy_func(*args, **kwargs): # pylint: disable=unused-argument optree.functools.partial(dummy_partial_func, 1, 2, 3, x=4, y=5), ], none_is_leaf=[False, True], + namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_partial_round_trip(tree, none_is_leaf): - leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf) - actual = optree.tree_unflatten(treespec, leaves) - assert actual.func == tree.func - assert actual.args == tree.args - assert actual.keywords == tree.keywords +def test_partial_round_trip( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace) + actual = optree.tree_unflatten(treespec, leaves) + assert actual.func == tree.func + assert actual.args == tree.args + assert actual.keywords == tree.keywords + assert tuple(actual.keywords.items()) == tuple(tree.keywords.items()) def test_partial_does_not_merge_with_other_partials(): @@ -78,13 +92,27 @@ def test_partial_func_attribute_has_stable_hash(): optree.Partial(dummy_partial_func, 1, 2, 3, x=4, y=5), ], none_is_leaf=[False, True], + namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_Partial_round_trip(tree, none_is_leaf): # noqa: N802 - leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf) - actual = optree.tree_unflatten(treespec, leaves) - assert actual.func == tree.func - assert actual.args == tree.args - assert actual.keywords == tree.keywords +def test_Partial_round_trip( # noqa: N802 + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace) + actual = optree.tree_unflatten(treespec, leaves) + assert actual.func == tree.func + assert actual.args == tree.args + assert actual.keywords == tree.keywords + assert tuple(actual.keywords.items()) == tuple(tree.keywords.items()) def test_Partial_does_not_merge_with_other_partials(): # noqa: N802 diff --git a/tests/test_ops.py b/tests/test_ops.py index f4dd66dc..421b3eac 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -27,6 +27,7 @@ import optree from helpers import ( + GLOBAL_NAMESPACE, IS_LEAF_FUNCTIONS, LEAVES, TREE_ACCESSORS, @@ -81,24 +82,48 @@ def test_max_depth(): tree=list(TREES + LEAVES), none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_round_trip(tree, none_is_leaf, namespace): - leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace) - actual = optree.tree_unflatten(treespec, leaves) - assert actual == tree +def test_round_trip( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace) + actual = optree.tree_unflatten(treespec, leaves) + assert actual == tree @parametrize( tree=list(TREES + LEAVES), none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_round_trip_with_flatten_up_to(tree, none_is_leaf, namespace): - _, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace) - leaves = treespec.flatten_up_to(tree) - actual = optree.tree_unflatten(treespec, leaves) - assert actual == tree - assert leaves == [accessor(tree) for accessor in optree.treespec_accessors(treespec)] +def test_round_trip_with_flatten_up_to( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + _, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace) + leaves = treespec.flatten_up_to(tree) + actual = optree.tree_unflatten(treespec, leaves) + assert actual == tree + assert leaves == [accessor(tree) for accessor in optree.treespec_accessors(treespec)] @parametrize( @@ -114,7 +139,6 @@ def test_round_trip_with_flatten_up_to(tree, none_is_leaf, namespace): ) def test_flatten_order(tree, none_is_leaf): flat, _ = optree.tree_flatten(tree, none_is_leaf=none_is_leaf) - assert flat == list(range(10)) @@ -163,14 +187,26 @@ def test_tree_unflatten_mismatch_number_of_leaves(tree, none_is_leaf, namespace) tree=list(TREES + LEAVES), none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_tree_iter(tree, none_is_leaf, namespace): - leaves = optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) - it = optree.tree_iter(tree, none_is_leaf=none_is_leaf, namespace=namespace) - assert iter(it) is it - assert list(it) == leaves - with pytest.raises(StopIteration): - next(it) +def test_tree_iter( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + leaves = optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) + it = optree.tree_iter(tree, none_is_leaf=none_is_leaf, namespace=namespace) + assert iter(it) is it + assert list(it) == leaves + with pytest.raises(StopIteration): + next(it) def test_walk(): @@ -385,73 +421,86 @@ def test_paths_and_accessors(data): is_leaf=IS_LEAF_FUNCTIONS, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_paths_and_accessors_with_is_leaf(tree, is_leaf, none_is_leaf, namespace): - expected_leaves, expected_treespec = optree.tree_flatten( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - paths, leaves, treespec = optree.tree_flatten_with_path( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - accessors, other_leaves, other_treespec = optree.tree_flatten_with_accessor( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - assert len(paths) == len(leaves) - assert len(accessors) == len(leaves) - assert leaves == expected_leaves - assert treespec == expected_treespec - assert other_leaves == expected_leaves - assert other_treespec == expected_treespec - for leaf, accessor, path in zip(leaves, accessors, paths): - assert isinstance(accessor, optree.PyTreeAccessor) - assert isinstance(path, tuple) - assert len(accessor) == len(path) - assert all( - isinstance(e, optree.PyTreeEntry) - and isinstance(e.type, type) - and isinstance(e.kind, optree.PyTreeKind) - for e in accessor +def test_paths_and_accessors_with_is_leaf( + tree, + is_leaf, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + expected_leaves, expected_treespec = optree.tree_flatten( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, ) - assert accessor.path == path - assert tuple(e.entry for e in accessor) == path - assert accessor(tree) == leaf - if all(e.__class__.codify is not optree.PyTreeEntry.codify for e in accessor): - # pylint: disable-next=eval-used - assert eval(accessor.codify('__tree'), {'__tree': tree}, {}) == leaf - # pylint: disable-next=eval-used - assert eval(f'lambda __tree: {accessor.codify("__tree")}', {}, {})(tree) == leaf - else: - assert 'flat index' in accessor.codify('') - - assert optree.treespec_paths(treespec) == paths - assert optree.treespec_accessors(treespec) == accessors - assert ( - optree.tree_paths( + paths, leaves, treespec = optree.tree_flatten_with_path( tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace, ) - == paths - ) - assert ( - optree.tree_accessors( + accessors, other_leaves, other_treespec = optree.tree_flatten_with_accessor( tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace, ) - == accessors - ) + assert len(paths) == len(leaves) + assert len(accessors) == len(leaves) + assert leaves == expected_leaves + assert treespec == expected_treespec + assert other_leaves == expected_leaves + assert other_treespec == expected_treespec + for leaf, accessor, path in zip(leaves, accessors, paths): + assert isinstance(accessor, optree.PyTreeAccessor) + assert isinstance(path, tuple) + assert len(accessor) == len(path) + assert all( + isinstance(e, optree.PyTreeEntry) + and isinstance(e.type, type) + and isinstance(e.kind, optree.PyTreeKind) + for e in accessor + ) + assert accessor.path == path + assert tuple(e.entry for e in accessor) == path + assert accessor(tree) == leaf + if all(e.__class__.codify is not optree.PyTreeEntry.codify for e in accessor): + # pylint: disable-next=eval-used + assert eval(accessor.codify('__tree'), {'__tree': tree}, {}) == leaf + # pylint: disable-next=eval-used + assert eval(f'lambda __tree: {accessor.codify("__tree")}', {}, {})(tree) == leaf + else: + assert 'flat index' in accessor.codify('') + + assert optree.treespec_paths(treespec) == paths + assert optree.treespec_accessors(treespec) == accessors + assert ( + optree.tree_paths( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + == paths + ) + assert ( + optree.tree_accessors( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + == accessors + ) @parametrize( @@ -459,40 +508,77 @@ def test_paths_and_accessors_with_is_leaf(tree, is_leaf, none_is_leaf, namespace is_leaf=IS_LEAF_FUNCTIONS, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_round_trip_is_leaf(tree, is_leaf, none_is_leaf, namespace): - subtrees, treespec = optree.tree_flatten( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - actual = optree.tree_unflatten(treespec, subtrees) - assert actual == tree +def test_round_trip_is_leaf( + tree, + is_leaf, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + subtrees, treespec = optree.tree_flatten( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + actual = optree.tree_unflatten(treespec, subtrees) + assert actual == tree @parametrize( tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_tree_is_leaf_with_trees(tree, none_is_leaf, namespace): - leaves = optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) - for leaf in leaves: - assert optree.tree_is_leaf(leaf, none_is_leaf=none_is_leaf, namespace=namespace) - if [tree] != leaves: - assert not optree.tree_is_leaf(tree, none_is_leaf=none_is_leaf, namespace=namespace) - else: - assert optree.tree_is_leaf(tree, none_is_leaf=none_is_leaf, namespace=namespace) +def test_tree_is_leaf_with_trees( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + leaves = optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) + for leaf in leaves: + assert optree.tree_is_leaf(leaf, none_is_leaf=none_is_leaf, namespace=namespace) + if [tree] != leaves: + assert not optree.tree_is_leaf(tree, none_is_leaf=none_is_leaf, namespace=namespace) + else: + assert optree.tree_is_leaf(tree, none_is_leaf=none_is_leaf, namespace=namespace) @parametrize( leaf=LEAVES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_tree_is_leaf_with_leaves(leaf, none_is_leaf, namespace): - assert optree.tree_is_leaf(leaf, none_is_leaf=none_is_leaf, namespace=namespace) +def test_tree_is_leaf_with_leaves( + leaf, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + assert optree.tree_is_leaf(leaf, none_is_leaf=none_is_leaf, namespace=namespace) @parametrize( @@ -500,56 +586,93 @@ def test_tree_is_leaf_with_leaves(leaf, none_is_leaf, namespace): is_leaf=IS_LEAF_FUNCTIONS, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_tree_is_leaf_with_is_leaf(tree, is_leaf, none_is_leaf, namespace): - leaves = optree.tree_leaves( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - for leaf in leaves: - assert optree.tree_is_leaf( - leaf, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - if [tree] != leaves: - assert not optree.tree_is_leaf( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - else: - assert optree.tree_is_leaf( +def test_tree_is_leaf_with_is_leaf( + tree, + is_leaf, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + leaves = optree.tree_leaves( tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace, ) + for leaf in leaves: + assert optree.tree_is_leaf( + leaf, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + if [tree] != leaves: + assert not optree.tree_is_leaf( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + else: + assert optree.tree_is_leaf( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) @parametrize( tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_all_leaves_with_trees(tree, none_is_leaf, namespace): - leaves = optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) - assert optree.all_leaves(leaves, none_is_leaf=none_is_leaf, namespace=namespace) - if [tree] != leaves: - assert not optree.all_leaves([tree], none_is_leaf=none_is_leaf, namespace=namespace) +def test_all_leaves_with_trees( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + leaves = optree.tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace) + assert optree.all_leaves(leaves, none_is_leaf=none_is_leaf, namespace=namespace) + if [tree] != leaves: + assert not optree.all_leaves([tree], none_is_leaf=none_is_leaf, namespace=namespace) @parametrize( leaf=LEAVES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_all_leaves_with_leaves(leaf, none_is_leaf, namespace): - assert optree.all_leaves([leaf], none_is_leaf=none_is_leaf, namespace=namespace) +def test_all_leaves_with_leaves( + leaf, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + assert optree.all_leaves([leaf], none_is_leaf=none_is_leaf, namespace=namespace) @parametrize( @@ -557,20 +680,33 @@ def test_all_leaves_with_leaves(leaf, none_is_leaf, namespace): is_leaf=IS_LEAF_FUNCTIONS, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_all_leaves_with_is_leaf(tree, is_leaf, none_is_leaf, namespace): - leaves = optree.tree_leaves( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - assert optree.all_leaves( - leaves, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) +def test_all_leaves_with_is_leaf( + tree, + is_leaf, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + leaves = optree.tree_leaves( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + assert optree.all_leaves( + leaves, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) def test_tree_map(): @@ -2509,40 +2645,58 @@ def test_tree_replace_nones(): tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_tree_transpose(tree, none_is_leaf, namespace): - outer_treespec = optree.tree_structure( - tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - inner_treespec = optree.tree_structure( - [1, 1, 1], - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - nested = optree.tree_map( - lambda x: [x, x, x], - tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - if outer_treespec.num_leaves == 0: - with pytest.raises(ValueError, match='Tree structures must have at least one leaf.'): - optree.tree_transpose(outer_treespec, inner_treespec, nested) - return - with pytest.raises(ValueError, match='Tree structures must have the same none_is_leaf value.'): - optree.tree_transpose( - outer_treespec, - optree.tree_structure( - [1, 1, 1], - none_is_leaf=not none_is_leaf, - namespace=namespace, - ), - nested, +def test_tree_transpose( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + outer_treespec = optree.tree_structure( + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, ) - actual = optree.tree_transpose(outer_treespec, inner_treespec, nested) - assert actual == [tree, tree, tree] + inner_treespec = optree.tree_structure( + [1, 1, 1], + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + nested = optree.tree_map( + lambda x: [x, x, x], + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + if outer_treespec.num_leaves == 0: + with pytest.raises( + ValueError, + match=re.escape('Tree structures must have at least one leaf.'), + ): + optree.tree_transpose(outer_treespec, inner_treespec, nested) + return + with pytest.raises( + ValueError, + match=re.escape('Tree structures must have the same none_is_leaf value.'), + ): + optree.tree_transpose( + outer_treespec, + optree.tree_structure( + [1, 1, 1], + none_is_leaf=not none_is_leaf, + namespace=namespace, + ), + nested, + ) + actual = optree.tree_transpose(outer_treespec, inner_treespec, nested) + assert actual == [tree, tree, tree] def test_tree_transpose_mismatch_outer(): @@ -2961,111 +3115,133 @@ def test_tree_any(): tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_tree_flatten_one_level(tree, none_is_leaf, namespace): # noqa: C901 - actual_leaves = [] - actual_paths = [] - actual_typed_paths = [] - - path_stack = [] - typed_path_stack = [] - - def flatten(node): # noqa: C901 - counter = itertools.count() - expected_children, one_level_treespec = optree.tree_flatten( - node, - is_leaf=lambda x: next(counter) > 0, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - node_type = type(node) - node_kind = one_level_treespec.kind - if one_level_treespec.is_leaf(): - assert expected_children == [node] - assert node_kind == optree.PyTreeKind.LEAF - with pytest.raises( - ValueError, - match=re.escape(f'Cannot flatten leaf-type: {node_type} (node: {node!r}).'), - ): - optree.tree_flatten_one_level(node, none_is_leaf=none_is_leaf, namespace=namespace) - actual_leaves.append(node) - actual_paths.append(tuple(path_stack)) - actual_typed_paths.append(tuple(typed_path_stack)) - else: - children, metadata, entries, unflatten_func = optree.tree_flatten_one_level( +def test_tree_flatten_one_level( # noqa: C901 + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + actual_leaves = [] + actual_paths = [] + actual_typed_paths = [] + + path_stack = [] + typed_path_stack = [] + + def flatten(node): # noqa: C901 + counter = itertools.count() + expected_children, one_level_treespec = optree.tree_flatten( node, + is_leaf=lambda x: next(counter) > 0, none_is_leaf=none_is_leaf, namespace=namespace, ) - assert children == expected_children - if node_type in {type(None), tuple, list}: - assert metadata is None - if node_type is tuple: - assert node_kind == optree.PyTreeKind.TUPLE - elif node_type is list: - assert node_kind == optree.PyTreeKind.LIST - else: - assert node_kind == optree.PyTreeKind.NONE - elif node_type is dict: - assert metadata == sorted(node.keys()) - assert node_kind == optree.PyTreeKind.DICT - elif node_type is OrderedDict: - assert metadata == list(node.keys()) - assert node_kind == optree.PyTreeKind.ORDEREDDICT - elif node_type is defaultdict: - assert metadata == (node.default_factory, sorted(node.keys())) - assert node_kind == optree.PyTreeKind.DEFAULTDICT - elif node_type is deque: - assert metadata == node.maxlen - assert node_kind == optree.PyTreeKind.DEQUE - elif optree.is_structseq(node): - assert optree.is_structseq_class(node_type) - assert isinstance(node, optree.typing.structseq) - assert issubclass(node_type, optree.typing.structseq) - assert metadata is node_type - assert node_kind == optree.PyTreeKind.STRUCTSEQUENCE - elif optree.is_namedtuple(node): - assert optree.is_namedtuple_class(node_type) - assert metadata is node_type - assert node_kind == optree.PyTreeKind.NAMEDTUPLE + node_type = type(node) + node_kind = one_level_treespec.kind + if one_level_treespec.is_leaf(): + assert expected_children == [node] + assert node_kind == optree.PyTreeKind.LEAF + with pytest.raises( + ValueError, + match=re.escape(f'Cannot flatten leaf-type: {node_type} (node: {node!r}).'), + ): + optree.tree_flatten_one_level( + node, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + actual_leaves.append(node) + actual_paths.append(tuple(path_stack)) + actual_typed_paths.append(tuple(typed_path_stack)) else: - assert node_kind == optree.PyTreeKind.CUSTOM - assert len(entries) == len(children) - if hasattr(node, '__getitem__'): + children, metadata, entries, unflatten_func = optree.tree_flatten_one_level( + node, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + assert children == expected_children + if node_type in {type(None), tuple, list}: + assert metadata is None + if node_type is tuple: + assert node_kind == optree.PyTreeKind.TUPLE + elif node_type is list: + assert node_kind == optree.PyTreeKind.LIST + else: + assert node_kind == optree.PyTreeKind.NONE + elif node_type is dict: + if dict_should_be_sorted or dict_session_namespace not in {'', namespace}: + assert metadata == sorted(node.keys()) + else: + assert metadata == list(node.keys()) + assert node_kind == optree.PyTreeKind.DICT + elif node_type is OrderedDict: + assert metadata == list(node.keys()) + assert node_kind == optree.PyTreeKind.ORDEREDDICT + elif node_type is defaultdict: + if dict_should_be_sorted or dict_session_namespace not in {'', namespace}: + assert metadata == (node.default_factory, sorted(node.keys())) + else: + assert metadata == (node.default_factory, list(node.keys())) + assert node_kind == optree.PyTreeKind.DEFAULTDICT + elif node_type is deque: + assert metadata == node.maxlen + assert node_kind == optree.PyTreeKind.DEQUE + elif optree.is_structseq(node): + assert optree.is_structseq_class(node_type) + assert isinstance(node, optree.typing.structseq) + assert issubclass(node_type, optree.typing.structseq) + assert metadata is node_type + assert node_kind == optree.PyTreeKind.STRUCTSEQUENCE + elif optree.is_namedtuple(node): + assert optree.is_namedtuple_class(node_type) + assert metadata is node_type + assert node_kind == optree.PyTreeKind.NAMEDTUPLE + else: + assert node_kind == optree.PyTreeKind.CUSTOM + assert len(entries) == len(children) + if hasattr(node, '__getitem__'): + for child, entry in zip(children, entries): + assert node[entry] is child + + assert unflatten_func(metadata, children) == node + if node_type is type(None): + assert unflatten_func(metadata, []) is None + with pytest.raises(ValueError, match=re.escape('Expected no children.')): + unflatten_func(metadata, range(1)) + for child, entry in zip(children, entries): - assert node[entry] is child - - assert unflatten_func(metadata, children) == node - if node_type is type(None): - assert unflatten_func(metadata, []) is None - with pytest.raises(ValueError, match=re.escape('Expected no children.')): - unflatten_func(metadata, range(1)) - - for child, entry in zip(children, entries): - path_stack.append(entry) - typed_path_stack.append((entry, node_type, node_kind)) - flatten(child) - path_stack.pop() - typed_path_stack.pop() - - flatten(tree) - assert len(path_stack) == 0 - assert len(typed_path_stack) == 0 - assert actual_leaves == optree.tree_leaves( - tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - assert actual_paths == optree.tree_paths( - tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - assert actual_typed_paths == [ - tuple((e.entry, e.type, e.kind) for e in accessor) - for accessor in optree.tree_accessors( + path_stack.append(entry) + typed_path_stack.append((entry, node_type, node_kind)) + flatten(child) + path_stack.pop() + typed_path_stack.pop() + + flatten(tree) + assert len(path_stack) == 0 + assert len(typed_path_stack) == 0 + assert actual_leaves == optree.tree_leaves( tree, none_is_leaf=none_is_leaf, namespace=namespace, ) - ] + assert actual_paths == optree.tree_paths( + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + assert actual_typed_paths == [ + tuple((e.entry, e.type, e.kind) for e in accessor) + for accessor in optree.tree_accessors( + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + ] diff --git a/tests/test_prefix_errors.py b/tests/test_prefix_errors.py index 379d9565..b2910b69 100644 --- a/tests/test_prefix_errors.py +++ b/tests/test_prefix_errors.py @@ -23,7 +23,15 @@ import pytest import optree -from helpers import TREES, CustomTuple, FlatCache, TimeStructTimeType, Vector2D, parametrize +from helpers import ( + GLOBAL_NAMESPACE, + TREES, + CustomTuple, + FlatCache, + TimeStructTimeType, + Vector2D, + parametrize, +) from optree.registry import ( AttributeKeyPathEntry, FlattenedKeyPathEntry, @@ -406,78 +414,90 @@ def test_different_metadata_multiple(): tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_standard_dictionary(tree, none_is_leaf, namespace): - random.seed(0) +def test_standard_dictionary( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + random.seed(0) - def build_subtree(x): - return random.choice([x, (x,), [x, x], (x, [x]), {'a': x, 'b': [x]}]) + def build_subtree(x): + return random.choice([x, (x,), [x, x], (x, [x]), {'a': x, 'b': [x]}]) - suffix_tree = optree.tree_map( - build_subtree, - tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - treespec = optree.tree_structure( - tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) + suffix_tree = optree.tree_map( + build_subtree, + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + treespec = optree.tree_structure( + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) - if 'FlatCache' in str(treespec): - return - - def shuffle_dictionary(x): - if type(x) in {dict, OrderedDict, defaultdict}: - items = list(x.items()) - random.shuffle(items) - dict_type = random.choice([dict, OrderedDict, defaultdict]) - if dict_type is defaultdict: - return defaultdict(getattr(x, 'default_factory', int), items) - return dict_type(items) - return x - - shuffled_tree = optree.tree_map( - shuffle_dictionary, - tree, - is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict}, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - shuffled_treespec = optree.tree_structure( - shuffled_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - shuffled_suffix_tree = optree.tree_map( - shuffle_dictionary, - suffix_tree, - is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict}, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - shuffled_suffix_treespec = optree.tree_structure( - shuffled_suffix_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) + if 'FlatCache' in str(treespec): + return + + def shuffle_dictionary(x): + if type(x) in {dict, OrderedDict, defaultdict}: + items = list(x.items()) + random.shuffle(items) + dict_type = random.choice([dict, OrderedDict, defaultdict]) + if dict_type is defaultdict: + return defaultdict(getattr(x, 'default_factory', int), items) + return dict_type(items) + return x + + shuffled_tree = optree.tree_map( + shuffle_dictionary, + tree, + is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict}, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + shuffled_treespec = optree.tree_structure( + shuffled_tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + shuffled_suffix_tree = optree.tree_map( + shuffle_dictionary, + suffix_tree, + is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict}, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + shuffled_suffix_treespec = optree.tree_structure( + shuffled_suffix_tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) - # Ignore dictionary types and key ordering - optree.tree_map_( - lambda x, y: None, - shuffled_tree, - shuffled_suffix_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - assert shuffled_treespec.is_prefix(shuffled_suffix_treespec) - () == optree.prefix_errors( # noqa: B015 - shuffled_tree, - shuffled_suffix_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) + # Ignore dictionary types and key ordering + optree.tree_map_( + lambda x, y: None, + shuffled_tree, + shuffled_suffix_tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + assert shuffled_treespec.is_prefix(shuffled_suffix_treespec) + () == optree.prefix_errors( # noqa: B015 + shuffled_tree, + shuffled_suffix_tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) def test_namedtuple(): diff --git a/tests/test_registry.py b/tests/test_registry.py index dcd86ad5..a7bf3271 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -22,7 +22,7 @@ import pytest import optree -from helpers import gc_collect, skipif_pypy +from helpers import GLOBAL_NAMESPACE, gc_collect, skipif_pypy def test_register_pytree_node_class_with_no_namespace(): @@ -60,7 +60,7 @@ def tree_unflatten(cls, metadata, children): def test_register_pytree_node_with_non_class(): with pytest.raises(TypeError, match='Expected a class'): - @optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) + @optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) def func1(): pass @@ -69,7 +69,7 @@ def func1(): 1, lambda s: (sorted(s), None, None), lambda _, s: set(s), - namespace=optree.registry.__GLOBAL_NAMESPACE, + namespace=GLOBAL_NAMESPACE, ) with pytest.raises(TypeError, match='Expected a class'): @@ -250,7 +250,7 @@ def test_register_pytree_node_duplicate_builtins(): type(None), lambda n: ((), None, None), lambda _, n: None, - namespace=optree.registry.__GLOBAL_NAMESPACE, + namespace=GLOBAL_NAMESPACE, ) with pytest.raises( @@ -276,7 +276,7 @@ def test_register_pytree_node_duplicate_builtins(): list, lambda lst: (lst, None, None), lambda _, lst: lst, - namespace=optree.registry.__GLOBAL_NAMESPACE, + namespace=GLOBAL_NAMESPACE, ) with pytest.raises( ValueError, @@ -306,7 +306,7 @@ def test_register_pytree_node_namedtuple(): mytuple1, lambda t: (reversed(t), None, None), lambda _, t: mytuple1(*reversed(t)), - namespace=optree.registry.__GLOBAL_NAMESPACE, + namespace=GLOBAL_NAMESPACE, ) with pytest.warns( UserWarning, @@ -469,7 +469,7 @@ def test_pytree_node_registry_get(): set, lambda s: (sorted(s), None, None), lambda _, s: set(s), - namespace=optree.registry.__GLOBAL_NAMESPACE, + namespace=GLOBAL_NAMESPACE, ) handler = optree.register_pytree_node.get(set) assert handler is not None @@ -568,10 +568,10 @@ def func1(): pass with pytest.raises(TypeError, match='Expected a class'): - optree.unregister_pytree_node(func1, namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(func1, namespace=GLOBAL_NAMESPACE) with pytest.raises(TypeError, match='Expected a class'): - optree.unregister_pytree_node(1, namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(1, namespace=GLOBAL_NAMESPACE) def func2(): pass @@ -602,7 +602,7 @@ def tree_unflatten(cls, metadata, children): ValueError, match=r"PyTree type is not registered in the global namespace\.", ): - optree.unregister_pytree_node(MyList, namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(MyList, namespace=GLOBAL_NAMESPACE) optree.register_pytree_node_class(MyList, namespace='mylist') @@ -610,7 +610,7 @@ def tree_unflatten(cls, metadata, children): ValueError, match=r"PyTree type is not registered in the global namespace\.", ): - optree.unregister_pytree_node(MyList, namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(MyList, namespace=GLOBAL_NAMESPACE) optree.unregister_pytree_node(MyList, namespace='mylist') @@ -636,7 +636,7 @@ def test_unregister_pytree_node_with_builtins(): r"PyTree type is a built-in type and cannot be unregistered.", ), ): - optree.unregister_pytree_node(type(None), namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(type(None), namespace=GLOBAL_NAMESPACE) with pytest.raises( ValueError, @@ -652,7 +652,7 @@ def test_unregister_pytree_node_with_builtins(): r"PyTree type is a built-in type and cannot be unregistered.", ), ): - optree.unregister_pytree_node(list, namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(list, namespace=GLOBAL_NAMESPACE) with pytest.raises( ValueError, @@ -677,7 +677,7 @@ def test_unregister_pytree_node_namedtuple(): mytuple1, lambda t: (reversed(t), None, None), lambda _, t: mytuple1(*reversed(t)), - namespace=optree.registry.__GLOBAL_NAMESPACE, + namespace=GLOBAL_NAMESPACE, ) tree = mytuple1(1, 2, 3) @@ -686,7 +686,7 @@ def test_unregister_pytree_node_namedtuple(): assert str(treespec1) == 'PyTreeSpec(CustomTreeNode(mytuple1[None], [*, *, *]))' assert tree == optree.tree_unflatten(treespec1, leaves1) - optree.unregister_pytree_node(mytuple1, namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(mytuple1, namespace=GLOBAL_NAMESPACE) assert str(treespec1) == 'PyTreeSpec(CustomTreeNode(mytuple1[None], [*, *, *]))' assert tree == optree.tree_unflatten(treespec1, leaves1) @@ -711,7 +711,7 @@ def test_unregister_pytree_node_namedtuple(): r'which is not explicitly registered in the global namespace.', ), ): - optree.unregister_pytree_node(mytuple1, namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(mytuple1, namespace=GLOBAL_NAMESPACE) mytuple2 = namedtuple('mytuple2', ['a', 'b', 'c']) # noqa: PYI024 with pytest.warns( @@ -768,7 +768,7 @@ def test_unregister_pytree_node_namedtuple(): @skipif_pypy def test_unregister_pytree_node_memory_leak(): # noqa: C901 - @optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) + @optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) class MyList1(UserList): def tree_flatten(self): return self.data, None, None @@ -780,12 +780,12 @@ def tree_unflatten(cls, metadata, children): wr = weakref.ref(MyList1) assert wr() is not None - optree.unregister_pytree_node(MyList1, namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(MyList1, namespace=GLOBAL_NAMESPACE) del MyList1 gc_collect() assert wr() is None - @optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) + @optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) class MyList2(UserList): def tree_flatten(self): return reversed(self.data), None, None @@ -801,7 +801,7 @@ def tree_unflatten(cls, metadata, children): assert leaves == [3, 2, 1] assert str(treespec) == 'PyTreeSpec(CustomTreeNode(MyList2[None], [*, *, *]))' - optree.unregister_pytree_node(MyList2, namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(MyList2, namespace=GLOBAL_NAMESPACE) del MyList2 gc_collect() assert wr() is not None @@ -812,7 +812,7 @@ def tree_unflatten(cls, metadata, children): gc_collect() assert wr() is None - @optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) + @optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) class MyList3(UserList): def tree_flatten(self): return reversed(self.data), None, None @@ -831,7 +831,7 @@ def tree_unflatten(cls, metadata, children): == "PyTreeSpec(CustomTreeNode(MyList3[None], [*, *, *]), namespace='undefined')" ) - optree.unregister_pytree_node(MyList3, namespace=optree.registry.__GLOBAL_NAMESPACE) + optree.unregister_pytree_node(MyList3, namespace=GLOBAL_NAMESPACE) del MyList3 gc_collect() assert wr() is not None @@ -887,3 +887,16 @@ def tree_unflatten(cls, metadata, children): del treespec gc_collect() assert wr() is None + + +def test_dict_insertion_order_with_invalid_namespace(): + with ( + pytest.raises(TypeError, match='The namespace must be a string'), + optree.dict_insertion_ordered(True, namespace=1), + ): + pass + with ( + pytest.raises(ValueError, match='The namespace cannot be an empty string.'), + optree.dict_insertion_ordered(True, namespace=''), + ): + pass diff --git a/tests/test_treespec.py b/tests/test_treespec.py index ed341aab..ededc39b 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -29,6 +29,7 @@ import helpers import optree from helpers import ( + GLOBAL_NAMESPACE, NAMESPACED_TREE, PYPY, TREE_STRINGS, @@ -68,53 +69,65 @@ def test_treespec_equal_hash(): tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_rich_compare(tree, none_is_leaf, namespace): - count = itertools.count() - - def build_subtree(x): - cnt = next(count) - if cnt % 4 == 0: - return (x,) - if cnt % 4 == 1: - return [x, x] - if cnt % 4 == 2: - return (x, [x]) - return {'a': x, 'b': [x]} - - treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) - suffix_treespec = optree.tree_structure( - optree.tree_map(build_subtree, tree, none_is_leaf=none_is_leaf, namespace=namespace), - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - assert treespec == treespec - assert not (treespec != treespec) - assert not (treespec < treespec) - assert not (treespec > treespec) - assert treespec <= treespec - assert treespec >= treespec - assert optree.treespec_is_prefix(treespec, treespec, strict=False) - assert not optree.treespec_is_prefix(treespec, treespec, strict=True) - assert optree.treespec_is_suffix(treespec, treespec, strict=False) - assert not optree.treespec_is_suffix(treespec, treespec, strict=True) - - if 'FlatCache' in str(treespec) or treespec == suffix_treespec: - return - - assert treespec != suffix_treespec - assert not (treespec == suffix_treespec) - assert treespec != suffix_treespec - assert treespec < suffix_treespec - assert not (treespec > suffix_treespec) - assert treespec <= suffix_treespec - assert not (treespec >= suffix_treespec) - assert suffix_treespec != treespec - assert not (suffix_treespec == treespec) - assert suffix_treespec > treespec - assert not (suffix_treespec < treespec) - assert suffix_treespec >= treespec - assert not (suffix_treespec <= treespec) +def test_treespec_rich_compare( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + count = itertools.count() + + def build_subtree(x): + cnt = next(count) + if cnt % 4 == 0: + return (x,) + if cnt % 4 == 1: + return [x, x] + if cnt % 4 == 2: + return (x, [x]) + return {'a': x, 'b': [x]} + + treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) + suffix_treespec = optree.tree_structure( + optree.tree_map(build_subtree, tree, none_is_leaf=none_is_leaf, namespace=namespace), + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + assert treespec == treespec + assert not (treespec != treespec) + assert not (treespec < treespec) + assert not (treespec > treespec) + assert treespec <= treespec + assert treespec >= treespec + assert optree.treespec_is_prefix(treespec, treespec, strict=False) + assert not optree.treespec_is_prefix(treespec, treespec, strict=True) + assert optree.treespec_is_suffix(treespec, treespec, strict=False) + assert not optree.treespec_is_suffix(treespec, treespec, strict=True) + + if 'FlatCache' in str(treespec) or treespec == suffix_treespec: + return + + assert treespec != suffix_treespec + assert not (treespec == suffix_treespec) + assert treespec != suffix_treespec + assert treespec < suffix_treespec + assert not (treespec > suffix_treespec) + assert treespec <= suffix_treespec + assert not (treespec >= suffix_treespec) + assert suffix_treespec != treespec + assert not (suffix_treespec == treespec) + assert suffix_treespec > treespec + assert not (suffix_treespec < treespec) + assert suffix_treespec >= treespec + assert not (suffix_treespec <= treespec) @parametrize( @@ -407,21 +420,33 @@ def test_treespec_with_namespace(): tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_pickle_round_trip(tree, none_is_leaf, namespace): - expected = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) - try: - pickle.loads(pickle.dumps(tree)) - except pickle.PicklingError: - with pytest.raises(pickle.PicklingError, match=r"Can't pickle .*:"): - pickle.loads(pickle.dumps(expected)) - else: - actual = pickle.loads(pickle.dumps(expected)) - assert actual == expected - if expected.type in {dict, OrderedDict, defaultdict}: - assert list(optree.tree_unflatten(actual, range(len(actual)))) == list( - optree.tree_unflatten(expected, range(len(expected))), - ) +def test_treespec_pickle_round_trip( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + expected = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) + try: + pickle.loads(pickle.dumps(tree)) + except pickle.PicklingError: + with pytest.raises(pickle.PicklingError, match=r"Can't pickle .*:"): + pickle.loads(pickle.dumps(expected)) + else: + actual = pickle.loads(pickle.dumps(expected)) + assert actual == expected + if expected.type in {dict, OrderedDict, defaultdict}: + assert list(optree.tree_unflatten(actual, range(len(actual)))) == list( + optree.tree_unflatten(expected, range(len(expected))), + ) class Foo: @@ -478,13 +503,25 @@ def test_treespec_pickle_missing_registration(): tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_type(tree, none_is_leaf, namespace): - treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) - if treespec.is_leaf(): - assert treespec.type is None - else: - assert type(tree) is treespec.type +def test_treespec_type( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) + if treespec.is_leaf(): + assert treespec.type is None + else: + assert type(tree) is treespec.type @parametrize( @@ -502,209 +539,267 @@ def test_treespec_type(tree, none_is_leaf, namespace): ], none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_compose_children(tree, inner_tree, none_is_leaf, namespace): - treespec = optree.tree_structure( - tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - inner_treespec = optree.tree_structure( - inner_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - expected_treespec = optree.tree_structure( - optree.tree_map( - lambda _: inner_tree, +def test_treespec_compose_children( + tree, + inner_tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + treespec = optree.tree_structure( tree, none_is_leaf=none_is_leaf, namespace=namespace, - ), - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - composed_treespec = treespec.compose(inner_treespec) - expected_leaves = treespec.num_leaves * inner_treespec.num_leaves - assert composed_treespec.num_leaves == treespec.num_leaves * inner_treespec.num_leaves - expected_nodes = (treespec.num_nodes - treespec.num_leaves) + ( - inner_treespec.num_nodes * treespec.num_leaves - ) - assert composed_treespec.num_nodes == expected_nodes - leaves = list(range(expected_leaves)) - composed = optree.tree_unflatten(composed_treespec, leaves) - assert leaves == optree.tree_leaves( - composed, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) + ) + inner_treespec = optree.tree_structure( + inner_tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + expected_treespec = optree.tree_structure( + optree.tree_map( + lambda _: inner_tree, + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ), + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + composed_treespec = treespec.compose(inner_treespec) + expected_leaves = treespec.num_leaves * inner_treespec.num_leaves + assert composed_treespec.num_leaves == treespec.num_leaves * inner_treespec.num_leaves + expected_nodes = (treespec.num_nodes - treespec.num_leaves) + ( + inner_treespec.num_nodes * treespec.num_leaves + ) + assert composed_treespec.num_nodes == expected_nodes + leaves = list(range(expected_leaves)) + composed = optree.tree_unflatten(composed_treespec, leaves) + assert leaves == optree.tree_leaves( + composed, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) - if 'FlatCache' in str(treespec): - return + if 'FlatCache' in str(treespec): + return - assert composed_treespec == expected_treespec + assert composed_treespec == expected_treespec - stack = [(composed_treespec.children(), expected_treespec.children())] - while stack: - composed_children, expected_children = stack.pop() - for composed_child, expected_child in zip(composed_children, expected_children): - assert composed_child == expected_child - stack.append((composed_child.children(), expected_child.children())) + stack = [(composed_treespec.children(), expected_treespec.children())] + while stack: + composed_children, expected_children = stack.pop() + for composed_child, expected_child in zip(composed_children, expected_children): + assert composed_child == expected_child + stack.append((composed_child.children(), expected_child.children())) - assert composed_treespec == optree.tree_structure( - composed, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) + assert composed_treespec == optree.tree_structure( + composed, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) - if treespec == expected_treespec: - assert not (treespec != expected_treespec) - assert not (treespec < expected_treespec) - assert treespec <= expected_treespec - assert not (treespec > expected_treespec) - assert treespec >= expected_treespec - assert expected_treespec >= treespec - assert not (expected_treespec > treespec) - assert expected_treespec <= treespec - assert not (expected_treespec < treespec) - assert not optree.treespec_is_prefix(treespec, expected_treespec, strict=True) - assert optree.treespec_is_prefix(treespec, expected_treespec, strict=False) - assert not optree.treespec_is_suffix(treespec, expected_treespec, strict=True) - assert optree.treespec_is_suffix(treespec, expected_treespec, strict=False) - assert not optree.treespec_is_prefix(expected_treespec, treespec, strict=True) - assert optree.treespec_is_prefix(expected_treespec, treespec, strict=False) - assert not optree.treespec_is_suffix(expected_treespec, treespec, strict=True) - assert optree.treespec_is_suffix(expected_treespec, treespec, strict=False) - else: - assert treespec != expected_treespec - assert treespec < expected_treespec - assert treespec <= expected_treespec - assert not (treespec > expected_treespec) - assert not (treespec >= expected_treespec) - assert expected_treespec >= treespec - assert expected_treespec > treespec - assert not (expected_treespec <= treespec) - assert not (expected_treespec < treespec) - assert optree.treespec_is_prefix(treespec, expected_treespec, strict=True) - assert optree.treespec_is_prefix(treespec, expected_treespec, strict=False) - assert not optree.treespec_is_suffix(treespec, expected_treespec, strict=True) - assert not optree.treespec_is_suffix(treespec, expected_treespec, strict=False) - assert not optree.treespec_is_prefix(expected_treespec, treespec, strict=True) - assert not optree.treespec_is_prefix(expected_treespec, treespec, strict=False) - assert optree.treespec_is_suffix(expected_treespec, treespec, strict=True) - assert optree.treespec_is_suffix(expected_treespec, treespec, strict=False) + if treespec == expected_treespec: + assert not (treespec != expected_treespec) + assert not (treespec < expected_treespec) + assert treespec <= expected_treespec + assert not (treespec > expected_treespec) + assert treespec >= expected_treespec + assert expected_treespec >= treespec + assert not (expected_treespec > treespec) + assert expected_treespec <= treespec + assert not (expected_treespec < treespec) + assert not optree.treespec_is_prefix(treespec, expected_treespec, strict=True) + assert optree.treespec_is_prefix(treespec, expected_treespec, strict=False) + assert not optree.treespec_is_suffix(treespec, expected_treespec, strict=True) + assert optree.treespec_is_suffix(treespec, expected_treespec, strict=False) + assert not optree.treespec_is_prefix(expected_treespec, treespec, strict=True) + assert optree.treespec_is_prefix(expected_treespec, treespec, strict=False) + assert not optree.treespec_is_suffix(expected_treespec, treespec, strict=True) + assert optree.treespec_is_suffix(expected_treespec, treespec, strict=False) + else: + assert treespec != expected_treespec + assert treespec < expected_treespec + assert treespec <= expected_treespec + assert not (treespec > expected_treespec) + assert not (treespec >= expected_treespec) + assert expected_treespec >= treespec + assert expected_treespec > treespec + assert not (expected_treespec <= treespec) + assert not (expected_treespec < treespec) + assert optree.treespec_is_prefix(treespec, expected_treespec, strict=True) + assert optree.treespec_is_prefix(treespec, expected_treespec, strict=False) + assert not optree.treespec_is_suffix(treespec, expected_treespec, strict=True) + assert not optree.treespec_is_suffix(treespec, expected_treespec, strict=False) + assert not optree.treespec_is_prefix(expected_treespec, treespec, strict=True) + assert not optree.treespec_is_prefix(expected_treespec, treespec, strict=False) + assert optree.treespec_is_suffix(expected_treespec, treespec, strict=True) + assert optree.treespec_is_suffix(expected_treespec, treespec, strict=False) @parametrize( tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_entries(tree, none_is_leaf, namespace): - expected_paths, _, treespec = optree.tree_flatten_with_path( - tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - assert optree.treespec_paths(treespec) == expected_paths - - def gen_path(spec): - entries = optree.treespec_entries(spec) - children = optree.treespec_children(spec) - assert len(entries) == spec.num_children - assert len(children) == spec.num_children - assert entries is not optree.treespec_entries(spec) - assert children is not optree.treespec_children(spec) - optree.treespec_entries(spec).clear() - optree.treespec_children(spec).clear() - - if spec.is_leaf(): - assert spec.num_children == 0 - yield () - return - - for entry, child in zip(entries, children): - for suffix in gen_path(child): - yield (entry, *suffix) - - paths = list(gen_path(treespec)) - assert paths == expected_paths - - expected_accessors, _, other_treespec = optree.tree_flatten_with_accessor( - tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - assert optree.treespec_accessors(treespec) == expected_accessors - assert optree.treespec_accessors(other_treespec) == expected_accessors - assert treespec == other_treespec - - def gen_typed_path(spec): - entries = optree.treespec_entries(spec) - children = optree.treespec_children(spec) - assert len(entries) == spec.num_children - assert len(children) == spec.num_children - - if spec.is_leaf(): - assert spec.num_children == 0 - yield () - return - - node_type = spec.type - node_kind = spec.kind - for entry, child in zip(entries, children): - for suffix in gen_typed_path(child): - yield ((entry, node_type, node_kind), *suffix) - - typed_paths = list(gen_typed_path(treespec)) - expected_typed_paths = [ - tuple((e.entry, e.type, e.kind) for e in accessor) for accessor in expected_accessors - ] - assert typed_paths == expected_typed_paths +def test_treespec_entries( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + expected_paths, _, treespec = optree.tree_flatten_with_path( + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + assert optree.treespec_paths(treespec) == expected_paths + + def gen_path(spec): + entries = optree.treespec_entries(spec) + children = optree.treespec_children(spec) + assert len(entries) == spec.num_children + assert len(children) == spec.num_children + assert entries is not optree.treespec_entries(spec) + assert children is not optree.treespec_children(spec) + optree.treespec_entries(spec).clear() + optree.treespec_children(spec).clear() + + if spec.is_leaf(): + assert spec.num_children == 0 + yield () + return + + for entry, child in zip(entries, children): + for suffix in gen_path(child): + yield (entry, *suffix) + + paths = list(gen_path(treespec)) + assert paths == expected_paths + + expected_accessors, _, other_treespec = optree.tree_flatten_with_accessor( + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + assert optree.treespec_accessors(treespec) == expected_accessors + assert optree.treespec_accessors(other_treespec) == expected_accessors + assert treespec == other_treespec + + def gen_typed_path(spec): + entries = optree.treespec_entries(spec) + children = optree.treespec_children(spec) + assert len(entries) == spec.num_children + assert len(children) == spec.num_children + + if spec.is_leaf(): + assert spec.num_children == 0 + yield () + return + + node_type = spec.type + node_kind = spec.kind + for entry, child in zip(entries, children): + for suffix in gen_typed_path(child): + yield ((entry, node_type, node_kind), *suffix) + + typed_paths = list(gen_typed_path(treespec)) + expected_typed_paths = [ + tuple((e.entry, e.type, e.kind) for e in accessor) for accessor in expected_accessors + ] + assert typed_paths == expected_typed_paths @parametrize( tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_entry(tree, none_is_leaf, namespace): - treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) - if treespec.type is None or treespec.type is type(None): - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Entry() index out of range.')): - optree.treespec_entry(treespec, 0) - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Entry() index out of range.')): - optree.treespec_entry(treespec, -1) - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Entry() index out of range.')): - optree.treespec_entry(treespec, 1) - if treespec.is_leaf(strict=False): - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Entry() index out of range.')): - optree.treespec_entry(treespec, 0) +def test_treespec_entry( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) + if treespec.type is None or treespec.type is type(None): + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Entry() index out of range.'), + ): + optree.treespec_entry(treespec, 0) + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Entry() index out of range.'), + ): + optree.treespec_entry(treespec, -1) + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Entry() index out of range.'), + ): + optree.treespec_entry(treespec, 1) + if treespec.is_leaf(strict=False): + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Entry() index out of range.'), + ): + optree.treespec_entry(treespec, 0) + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Entry() index out of range.'), + ): + optree.treespec_entry(treespec, -1) + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Entry() index out of range.'), + ): + optree.treespec_entry(treespec, 1) + expected_entries = optree.treespec_entries(treespec) + for i, entry in enumerate(expected_entries): + assert entry == optree.treespec_entry(treespec, i) + assert entry == optree.treespec_entry(treespec, i - len(expected_entries)) + assert optree.treespec_entry(treespec, i) == optree.treespec_entry(treespec, i) + assert optree.treespec_entry( + treespec, + i - len(expected_entries), + ) == optree.treespec_entry( + treespec, + i - len(expected_entries), + ) + assert optree.treespec_entry(treespec, i) == optree.treespec_entry( + treespec, + i - len(expected_entries), + ) with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Entry() index out of range.')): - optree.treespec_entry(treespec, -1) + optree.treespec_entry(treespec, len(expected_entries)) with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Entry() index out of range.')): - optree.treespec_entry(treespec, 1) - expected_entries = optree.treespec_entries(treespec) - for i, entry in enumerate(expected_entries): - assert entry == optree.treespec_entry(treespec, i) - assert entry == optree.treespec_entry(treespec, i - len(expected_entries)) - assert optree.treespec_entry(treespec, i) == optree.treespec_entry(treespec, i) - assert optree.treespec_entry(treespec, i - len(expected_entries)) == optree.treespec_entry( - treespec, - i - len(expected_entries), - ) - assert optree.treespec_entry(treespec, i) == optree.treespec_entry( - treespec, - i - len(expected_entries), - ) - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Entry() index out of range.')): - optree.treespec_entry(treespec, len(expected_entries)) - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Entry() index out of range.')): - optree.treespec_entry(treespec, -len(expected_entries) - 1) + optree.treespec_entry(treespec, -len(expected_entries) - 1) - assert expected_entries == [ - optree.treespec_entry(treespec, i) for i in range(len(expected_entries)) - ] + assert expected_entries == [ + optree.treespec_entry(treespec, i) for i in range(len(expected_entries)) + ] def test_treespec_children(): @@ -733,86 +828,155 @@ def test_treespec_children(): tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_child(tree, none_is_leaf, namespace): - treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) - if treespec.type is None or treespec.type is type(None): - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Child() index out of range.')): - optree.treespec_child(treespec, 0) - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Child() index out of range.')): - optree.treespec_child(treespec, -1) - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Child() index out of range.')): - optree.treespec_child(treespec, 1) - if treespec.is_leaf(strict=False): - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Child() index out of range.')): - optree.treespec_child(treespec, 0) +def test_treespec_child( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) + if treespec.type is None or treespec.type is type(None): + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Child() index out of range.'), + ): + optree.treespec_child(treespec, 0) + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Child() index out of range.'), + ): + optree.treespec_child(treespec, -1) + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Child() index out of range.'), + ): + optree.treespec_child(treespec, 1) + if treespec.is_leaf(strict=False): + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Child() index out of range.'), + ): + optree.treespec_child(treespec, 0) + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Child() index out of range.'), + ): + optree.treespec_child(treespec, -1) + with pytest.raises( + IndexError, + match=re.escape('PyTreeSpec::Child() index out of range.'), + ): + optree.treespec_child(treespec, 1) + expected_children = optree.treespec_children(treespec) + for i, child in enumerate(expected_children): + assert child == optree.treespec_child(treespec, i) + assert child == optree.treespec_child(treespec, i - len(expected_children)) + assert optree.treespec_child(treespec, i) == optree.treespec_child(treespec, i) + assert optree.treespec_child( + treespec, + i - len(expected_children), + ) == optree.treespec_child( + treespec, + i - len(expected_children), + ) + assert optree.treespec_child(treespec, i) == optree.treespec_child( + treespec, + i - len(expected_children), + ) with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Child() index out of range.')): - optree.treespec_child(treespec, -1) + optree.treespec_child(treespec, len(expected_children)) with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Child() index out of range.')): - optree.treespec_child(treespec, 1) - expected_children = optree.treespec_children(treespec) - for i, child in enumerate(expected_children): - assert child == optree.treespec_child(treespec, i) - assert child == optree.treespec_child(treespec, i - len(expected_children)) - assert optree.treespec_child(treespec, i) == optree.treespec_child(treespec, i) - assert optree.treespec_child(treespec, i - len(expected_children)) == optree.treespec_child( - treespec, - i - len(expected_children), - ) - assert optree.treespec_child(treespec, i) == optree.treespec_child( - treespec, - i - len(expected_children), - ) - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Child() index out of range.')): - optree.treespec_child(treespec, len(expected_children)) - with pytest.raises(IndexError, match=re.escape('PyTreeSpec::Child() index out of range.')): - optree.treespec_child(treespec, -len(expected_children) - 1) + optree.treespec_child(treespec, -len(expected_children) - 1) - assert expected_children == [ - optree.treespec_child(treespec, i) for i in range(len(expected_children)) - ] + assert expected_children == [ + optree.treespec_child(treespec, i) for i in range(len(expected_children)) + ] @parametrize( tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_num_nodes(tree, none_is_leaf, namespace): - treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) - nodes = [] - stack = [treespec] - while stack: - spec = stack.pop() - nodes.append(spec) - children = spec.children() - stack.extend(reversed(children)) - assert spec.num_nodes == sum(child.num_nodes for child in children) + 1 - assert treespec.num_nodes == len(nodes) +def test_treespec_num_nodes( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) + nodes = [] + stack = [treespec] + while stack: + spec = stack.pop() + nodes.append(spec) + children = spec.children() + stack.extend(reversed(children)) + assert spec.num_nodes == sum(child.num_nodes for child in children) + 1 + assert treespec.num_nodes == len(nodes) @parametrize( tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_num_leaves(tree, none_is_leaf, namespace): - leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace) - assert treespec.num_leaves == len(leaves) - assert treespec.num_leaves == len(treespec) - assert treespec.num_leaves == len(treespec.paths()) - assert treespec.num_leaves == len(treespec.accessors()) +def test_treespec_num_leaves( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf, namespace=namespace) + assert treespec.num_leaves == len(leaves) + assert treespec.num_leaves == len(treespec) + assert treespec.num_leaves == len(treespec.paths()) + assert treespec.num_leaves == len(treespec.accessors()) @parametrize( tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_num_children(tree, none_is_leaf, namespace): - treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) - assert treespec.num_children == len(treespec.entries()) - assert treespec.num_children == len(treespec.children()) +def test_treespec_num_children( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace) + assert treespec.num_children == len(treespec.entries()) + assert treespec.num_children == len(treespec.children()) def test_treespec_is_leaf(): @@ -957,73 +1121,171 @@ def test_treespec_leaf_none(namespace): tree=TREES, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], ) -def test_treespec_constructor(tree, none_is_leaf, namespace): # noqa: C901 - for passed_namespace in sorted({'', namespace}): - stack = [tree] - while stack: - node = stack.pop() - counter = itertools.count() - expected_treespec = optree.tree_structure( - node, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - children, one_level_treespec = optree.tree_flatten( - node, - is_leaf=lambda x: next(counter) > 0, # noqa: B023 - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - node_type = type(node) - if one_level_treespec.is_leaf(): - assert len(children) == 1 - with pytest.warns( - UserWarning, - match=re.escape('PyTreeSpec::MakeFromCollection() is called on a leaf.'), - ): +def test_treespec_constructor( # noqa: C901 + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + for passed_namespace in sorted({'', namespace}): + stack = [tree] + while stack: + node = stack.pop() + counter = itertools.count() + expected_treespec = optree.tree_structure( + node, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + children, one_level_treespec = optree.tree_flatten( + node, + is_leaf=lambda x: next(counter) > 0, # noqa: B023 + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + node_type = type(node) + if one_level_treespec.is_leaf(): + assert len(children) == 1 + with pytest.warns( + UserWarning, + match=re.escape('PyTreeSpec::MakeFromCollection() is called on a leaf.'), + ): + assert ( + optree.treespec_from_collection( + node, + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) assert ( - optree.treespec_from_collection( - node, + optree.treespec_leaf( none_is_leaf=none_is_leaf, namespace=passed_namespace, ) == expected_treespec ) - assert ( - optree.treespec_leaf( - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - else: - children_treespecs = [ - optree.tree_structure( - child, - none_is_leaf=none_is_leaf, - namespace=namespace, + else: + children_treespecs = [ + optree.tree_structure( + child, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + for child in children + ] + collection_of_treespecs = optree.tree_unflatten( + one_level_treespec, + children_treespecs, ) - for child in children - ] - collection_of_treespecs = optree.tree_unflatten( - one_level_treespec, - children_treespecs, - ) - assert ( - optree.treespec_from_collection( - collection_of_treespecs, - none_is_leaf=none_is_leaf, - namespace=namespace, + assert ( + optree.treespec_from_collection( + collection_of_treespecs, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + == expected_treespec ) - == expected_treespec - ) - if node_type in {type(None), tuple, list}: - if node_type is tuple: + if node_type in {type(None), tuple, list}: + if node_type is tuple: + assert ( + optree.treespec_tuple( + children_treespecs, + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + assert ( + optree.treespec_from_collection( + tuple(children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + elif node_type is list: + assert ( + optree.treespec_list( + children_treespecs, + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + assert ( + optree.treespec_from_collection( + list(children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + else: + assert len(children_treespecs) == 0 + assert ( + optree.treespec_none( + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + assert ( + optree.treespec_from_collection( + None, + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + elif node_type is dict: + if dict_should_be_sorted or dict_session_namespace not in {'', namespace}: + assert ( + optree.treespec_dict( + zip(sorted(node), children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + assert ( + optree.treespec_from_collection( + dict(zip(sorted(node), children_treespecs)), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + elif dict_session_namespace == passed_namespace: + assert ( + optree.treespec_dict( + zip(node, children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + assert ( + optree.treespec_from_collection( + dict(zip(node, children_treespecs)), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + elif node_type is OrderedDict: assert ( - optree.treespec_tuple( - children_treespecs, + optree.treespec_ordereddict( + zip(node, children_treespecs), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1031,16 +1293,60 @@ def test_treespec_constructor(tree, none_is_leaf, namespace): # noqa: C901 ) assert ( optree.treespec_from_collection( - tuple(children_treespecs), + OrderedDict(zip(node, children_treespecs)), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) == expected_treespec ) - elif node_type is list: + elif node_type is defaultdict: + if dict_should_be_sorted or dict_session_namespace not in {'', namespace}: + assert ( + optree.treespec_defaultdict( + node.default_factory, + zip(sorted(node), children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + assert ( + optree.treespec_from_collection( + defaultdict( + node.default_factory, + zip(sorted(node), children_treespecs), + ), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + elif dict_session_namespace == passed_namespace: + assert ( + optree.treespec_defaultdict( + node.default_factory, + zip(node, children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + assert ( + optree.treespec_from_collection( + defaultdict( + node.default_factory, + zip(node, children_treespecs), + ), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + elif node_type is deque: assert ( - optree.treespec_list( + optree.treespec_deque( children_treespecs, + maxlen=node.maxlen, none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1048,16 +1354,16 @@ def test_treespec_constructor(tree, none_is_leaf, namespace): # noqa: C901 ) assert ( optree.treespec_from_collection( - list(children_treespecs), + deque(children_treespecs, maxlen=node.maxlen), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) == expected_treespec ) - else: - assert len(children_treespecs) == 0 + elif optree.is_structseq(node): assert ( - optree.treespec_none( + optree.treespec_structseq( + node_type(children_treespecs), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1065,139 +1371,49 @@ def test_treespec_constructor(tree, none_is_leaf, namespace): # noqa: C901 ) assert ( optree.treespec_from_collection( - None, + node_type(children_treespecs), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) == expected_treespec ) - elif node_type is dict: - assert ( - optree.treespec_dict( - zip(sorted(node), children_treespecs), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - assert ( - optree.treespec_from_collection( - dict(zip(sorted(node), children_treespecs)), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - elif node_type is OrderedDict: - assert ( - optree.treespec_ordereddict( - zip(node, children_treespecs), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - assert ( - optree.treespec_from_collection( - OrderedDict(zip(node, children_treespecs)), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - elif node_type is defaultdict: - assert ( - optree.treespec_defaultdict( - node.default_factory, - zip(sorted(node), children_treespecs), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - assert ( - optree.treespec_from_collection( - defaultdict( - node.default_factory, - zip(sorted(node), children_treespecs), - ), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - elif node_type is deque: - assert ( - optree.treespec_deque( - children_treespecs, - maxlen=node.maxlen, - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - assert ( - optree.treespec_from_collection( - deque(children_treespecs, maxlen=node.maxlen), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - elif optree.is_structseq(node): - assert ( - optree.treespec_structseq( - node_type(children_treespecs), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - assert ( - optree.treespec_from_collection( - node_type(children_treespecs), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - with pytest.raises( - ValueError, - match=r'Expected a namedtuple of PyTreeSpec\(s\), got .*\.', - ): - optree.treespec_namedtuple( - node_type(children_treespecs), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - elif optree.is_namedtuple(node): - assert ( - optree.treespec_namedtuple( - node_type(*children_treespecs), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, - ) - == expected_treespec - ) - assert ( - optree.treespec_from_collection( - node_type(*children_treespecs), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, + with pytest.raises( + ValueError, + match=r'Expected a namedtuple of PyTreeSpec\(s\), got .*\.', + ): + optree.treespec_namedtuple( + node_type(children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + elif optree.is_namedtuple(node): + assert ( + optree.treespec_namedtuple( + node_type(*children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec ) - == expected_treespec - ) - with pytest.raises( - ValueError, - match=r'Expected a PyStructSequence of PyTreeSpec\(s\), got .*\.', - ): - optree.treespec_structseq( - node_type(*children_treespecs), - none_is_leaf=none_is_leaf, - namespace=passed_namespace, + assert ( + optree.treespec_from_collection( + node_type(*children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec ) + with pytest.raises( + ValueError, + match=r'Expected a PyStructSequence of PyTreeSpec\(s\), got .*\.', + ): + optree.treespec_structseq( + node_type(*children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) - stack.extend(reversed(children)) + stack.extend(reversed(children)) def test_treespec_constructor_namespace():