Skip to content

Commit

Permalink
feat(ops): add context manager to temporarily set the dictionary sort…
Browse files Browse the repository at this point in the history
…ing mode
  • Loading branch information
XuehaiPan committed Jul 3, 2024
1 parent b1521eb commit 5cf5c3b
Show file tree
Hide file tree
Showing 15 changed files with 1,502 additions and 891 deletions.
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Tree Manipulation Functions

.. autosummary::

dict_insertion_ordered
tree_flatten
tree_flatten_with_path
tree_flatten_with_accessor
Expand Down Expand Up @@ -55,6 +56,7 @@ Tree Manipulation Functions
tree_flatten_one_level
prefix_errors

.. autofunction:: dict_insertion_ordered
.. autofunction:: tree_flatten
.. autofunction:: tree_flatten_with_path
.. autofunction:: tree_flatten_with_accessor
Expand Down
37 changes: 32 additions & 5 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,23 @@ class PyTreeSpec {
const bool &none_is_leaf = false,
const std::string &registry_namespace = "");

// Report whether flattening should keep dictionary keys in insertion order for the given
// namespace. The result is true when either the global namespace ("") or the requested
// namespace is present in `sm_is_dict_insertion_ordered`.
inline static bool IsDictInsertionOrdered(const std::string &registry_namespace) {
    const auto is_registered = [](const std::string &ns) -> bool {
        return sm_is_dict_insertion_ordered.count(ns) > 0;
    };
    // The global namespace is consulted first, so enabling it applies to every namespace.
    return is_registered("") || is_registered(registry_namespace);
}

// Turn insertion-ordered dictionary flattening on (`mode == true`) or off (`mode == false`)
// for the given namespace by adding it to, or removing it from,
// `sm_is_dict_insertion_ordered`.
inline static void SetDictInsertionOrdered(const bool &mode,
                                           const std::string &registry_namespace) {
    if (!mode) [[unlikely]] {
        // Disabling: forget the namespace (a no-op if it was never registered).
        sm_is_dict_insertion_ordered.erase(registry_namespace);
        return;
    }
    // Enabling: register the namespace (a no-op if it is already registered).
    sm_is_dict_insertion_ordered.insert(registry_namespace);
}

private:
using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr;

Expand Down Expand Up @@ -266,7 +283,7 @@ class PyTreeSpec {
const bool &none_is_leaf,
const std::string &registry_namespace);

template <bool NoneIsLeaf, typename Span>
template <bool NoneIsLeaf, bool DictShouldBeSorted, typename Span>
bool FlattenIntoImpl(const py::handle &handle,
Span &leaves, // NOLINT[runtime/references]
const ssize_t &depth,
Expand All @@ -281,7 +298,11 @@ class PyTreeSpec {
const bool &none_is_leaf,
const std::string &registry_namespace);

template <bool NoneIsLeaf, typename LeafSpan, typename PathSpan, typename Stack>
template <bool NoneIsLeaf,
bool DictShouldBeSorted,
typename LeafSpan,
typename PathSpan,
typename Stack>
bool FlattenIntoWithPathImpl(const py::handle &handle,
LeafSpan &leaves, // NOLINT[runtime/references]
PathSpan &paths, // NOLINT[runtime/references]
Expand Down Expand Up @@ -329,6 +350,10 @@ class PyTreeSpec {
size_t operator()(const std::pair<const PyTreeSpec *, std::thread::id> &p) const;
};

// A set of namespaces that preserve the insertion order of the dictionary keys during
// flattening. Entries are added and removed by SetDictInsertionOrdered(); an entry for the
// empty string "" makes IsDictInsertionOrdered() return true for every namespace.
inline static std::unordered_set<std::string> sm_is_dict_insertion_ordered{};

// A set of (treespec, thread_id) pairs that are currently being represented as strings.
inline static std::unordered_set<std::pair<const PyTreeSpec *, std::thread::id>,
ThreadIndentTypeHash>
Expand All @@ -344,12 +369,13 @@ class PyTreeIter {
public:
PyTreeIter(const py::object &tree,
const std::optional<py::function> &leaf_predicate,
bool none_is_leaf,
std::string registry_namespace)
const bool &none_is_leaf,
const std::string &registry_namespace)
: m_agenda({std::make_pair(tree, 0)}),
m_leaf_predicate(leaf_predicate),
m_none_is_leaf(none_is_leaf),
m_namespace(std::move(registry_namespace)) {};
m_namespace(registry_namespace),
m_is_dict_insertion_ordered(PyTreeSpec::IsDictInsertionOrdered(registry_namespace)) {};

PyTreeIter() = delete;

Expand Down Expand Up @@ -377,6 +403,7 @@ class PyTreeIter {
std::optional<py::function> m_leaf_predicate;
bool m_none_is_leaf;
std::string m_namespace;
bool m_is_dict_insertion_ordered;

template <bool NoneIsLeaf>
[[nodiscard]] py::object NextImpl();
Expand Down
2 changes: 2 additions & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,5 @@ def unregister_node(
cls: type[CustomTreeNode[T]],
namespace: str = '',
) -> None: ...
def is_dict_insertion_ordered(namespace: str = '') -> bool: ...
def set_dict_insertion_ordered(mode: bool, namespace: str = '') -> None: ...
2 changes: 2 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from optree.registry import (
AttributeKeyPathEntry,
GetitemKeyPathEntry,
dict_insertion_ordered,
register_keypaths,
register_pytree_node,
register_pytree_node_class,
Expand Down Expand Up @@ -200,6 +201,7 @@
'register_pytree_node',
'register_pytree_node_class',
'unregister_pytree_node',
'dict_insertion_ordered',
# Typing
'PyTreeSpec',
'PyTreeDef',
Expand Down
89 changes: 89 additions & 0 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import contextlib
import dataclasses
import functools
import inspect
Expand All @@ -29,6 +30,7 @@
Any,
Callable,
ClassVar,
Generator,
Iterable,
NamedTuple,
Sequence,
Expand Down Expand Up @@ -70,6 +72,7 @@
'register_pytree_node',
'register_pytree_node_class',
'unregister_pytree_node',
'dict_insertion_ordered',
]


Expand Down Expand Up @@ -491,6 +494,17 @@ def _dict_unflatten(keys: list[KT], values: Iterable[VT]) -> dict[KT, VT]:
return dict(safe_zip(keys, values))


def _dict_insertion_ordered_flatten(
    dct: dict[KT, VT],
) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
    """Flatten a dict without sorting its keys.

    The items are taken in ``dct.items()`` order and split with ``unzip2``. The result is
    ``(children, metadata, entries)``: the values, the keys as a list, and the keys as
    returned by ``unzip2``.
    """
    entries, children = unzip2(dct.items())
    metadata = list(entries)
    return children, metadata, entries


def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT]) -> dict[KT, VT]:
    """Rebuild a dict from ``keys`` and ``values`` combined with ``safe_zip``, in the given order."""
    items = safe_zip(keys, values)
    return dict(items)


def _ordereddict_flatten(
dct: OrderedDict[KT, VT],
) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
Expand All @@ -517,6 +531,21 @@ def _defaultdict_unflatten(
return defaultdict(default_factory, _dict_unflatten(keys, values))


def _defaultdict_insertion_ordered_flatten(
    dct: defaultdict[KT, VT],
) -> tuple[tuple[VT, ...], tuple[Callable[[], VT] | None, list[KT]], tuple[KT, ...]]:
    """Flatten a defaultdict without sorting its keys.

    Delegates to ``_dict_insertion_ordered_flatten`` and wraps the key list together with
    ``dct.default_factory`` so the metadata is ``(default_factory, keys)``.
    """
    children, key_list, entries = _dict_insertion_ordered_flatten(dct)
    metadata = (dct.default_factory, key_list)
    return children, metadata, entries


def _defaultdict_insertion_ordered_unflatten(
    metadata: tuple[Callable[[], VT] | None, list[KT]],
    values: Iterable[VT],
) -> defaultdict[KT, VT]:
    """Rebuild a defaultdict from the metadata produced by the matching flatten function.

    Args:
        metadata: A ``(default_factory, keys)`` pair. ``default_factory`` may be ``None``,
            matching the return annotation of ``_defaultdict_insertion_ordered_flatten``.
        values: The values to associate with ``keys``, in the same order.

    Returns:
        A ``defaultdict`` with the given ``default_factory`` whose items come from
        ``_dict_insertion_ordered_unflatten(keys, values)``.
    """
    default_factory, keys = metadata
    return defaultdict(default_factory, _dict_insertion_ordered_unflatten(keys, values))


def _deque_flatten(deq: deque[T]) -> tuple[deque[T], int | None]:
return deq, deq.maxlen

Expand Down Expand Up @@ -566,6 +595,23 @@ def _pytree_node_registry_get(
handler = _NODETYPE_REGISTRY.get((namespace, cls))
if handler is not None:
return handler

if _C.is_dict_insertion_ordered(namespace):
if cls is dict:
return PyTreeNodeRegistryEntry(
dict,
_dict_insertion_ordered_flatten, # type: ignore[arg-type]
_dict_insertion_ordered_unflatten, # type: ignore[arg-type]
path_entry_type=MappingEntry,
)
if cls is defaultdict:
return PyTreeNodeRegistryEntry(
defaultdict,
_defaultdict_insertion_ordered_flatten, # type: ignore[arg-type]
_defaultdict_insertion_ordered_unflatten, # type: ignore[arg-type]
path_entry_type=MappingEntry,
)

handler = _NODETYPE_REGISTRY.get(cls)
if handler is not None:
return handler
Expand All @@ -580,6 +626,49 @@ def _pytree_node_registry_get(
del _pytree_node_registry_get


@contextlib.contextmanager
def dict_insertion_ordered(mode: bool, *, namespace: str) -> Generator[None, None, None]:
    """Context manager to temporarily set the dictionary sorting mode.

    This context manager is used to temporarily set the dictionary sorting mode for a specific
    namespace. The dictionary sorting mode determines whether the keys of a dictionary are
    sorted or kept in insertion order when flattening a pytree. The mode that was in effect on
    entry is written back on exit, even if the body raises.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [1, 2, 3, 4, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    )
    >>> with dict_insertion_ordered(True, namespace='some-namespace'):  # doctest: +IGNORE_WHITESPACE
    ...     tree_flatten(tree, namespace='some-namespace')
    (
        [2, 3, 4, 1, 5],
        PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *})
    )

    Args:
        mode (bool): The dictionary sorting mode to set. :data:`True` keeps the insertion
            order of the keys; :data:`False` removes that setting for the namespace.
        namespace (str): The namespace to set the dictionary sorting mode for.

    Raises:
        TypeError: If ``namespace`` is neither the global-namespace sentinel nor a string.
        ValueError: If ``namespace`` is an empty string.
    """
    targets_global = namespace is __GLOBAL_NAMESPACE
    if not (targets_global or isinstance(namespace, str)):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')
    # The C extension identifies the global namespace by the empty string.
    target = '' if targets_global else namespace

    # Read the current mode and install the requested one under the registry lock.
    # NOTE(review): `_C.is_dict_insertion_ordered` appears to report the *effective* mode
    # (global namespace or `target`), so writing `previous` back on exit may register `target`
    # explicitly when only the global namespace was enabled -- confirm this is intended.
    with __REGISTRY_LOCK:
        previous = _C.is_dict_insertion_ordered(target)
        _C.set_dict_insertion_ordered(mode, target)

    # The lock is not held while the caller's body runs.
    try:
        yield
    finally:
        with __REGISTRY_LOCK:
            _C.set_dict_insertion_ordered(previous, target)


####################################################################################################

warnings.filterwarnings('ignore', category=FutureWarning, module=__name__, append=True)
Expand Down
9 changes: 9 additions & 0 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
"Unregister a Python type.",
py::arg("cls"),
py::arg("namespace") = "")
.def("is_dict_insertion_ordered",
&PyTreeSpec::IsDictInsertionOrdered,
"Return whether need to preserve the dict insertion order during flattening.",
py::arg("namespace") = "")
.def("set_dict_insertion_ordered",
&PyTreeSpec::SetDictInsertionOrdered,
"Set whether need to preserve the dict insertion order during flattening.",
py::arg("mode"),
py::arg("namespace") = "")
.def("flatten",
&PyTreeSpec::Flatten,
"Flattens a pytree.",
Expand Down
4 changes: 3 additions & 1 deletion src/treespec/constructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ template <bool NoneIsLeaf>
py::list keys = DictKeys(dict);
if (node.kind != PyTreeKind::OrderedDict) [[likely]] {
node.original_keys = py::getattr(keys, Py_Get_ID(copy))();
TotalOrderSort(keys);
if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] {
TotalOrderSort(keys);
}
}
for (const py::handle& key : keys) {
children.emplace_back(dict[key]);
Expand Down
Loading

0 comments on commit 5cf5c3b

Please sign in to comment.