Skip to content

Commit

Permalink
feat(registry): add context manager to temporarily set the dictionary…
Browse files Browse the repository at this point in the history
… sorting mode (#147)
  • Loading branch information
XuehaiPan authored Jul 4, 2024
1 parent ddbd0e5 commit 9606d15
Show file tree
Hide file tree
Showing 20 changed files with 1,712 additions and 902 deletions.
64 changes: 57 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ jobs:
if-no-files-found: error

build-wheels:
name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }}
name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }} (${{ matrix.archs }})
runs-on: ${{ matrix.os }}
needs: [build-sdist]
if: github.repository == 'metaopt/optree' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
Expand All @@ -92,14 +92,65 @@ jobs:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version:
["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"]
archs: [
# Generic
"auto",
# Linux
"aarch64",
"ppc64le",
"s390x",
# Windows
"ARM64",
]
include:
- os: macos-13
python-version: "3.7"
archs: "auto"
exclude:
- os: ubuntu-latest
archs: "ARM64"
- os: windows-latest
archs: "aarch64"
- os: windows-latest
archs: "ppc64le"
- os: windows-latest
archs: "s390x"
- os: macos-latest
python-version: "3.7" # Python 3.7 does not support macOS ARM64
archs: "aarch64"
- os: macos-latest
archs: "ppc64le"
- os: macos-latest
archs: "s390x"
- os: macos-latest
archs: "ARM64"
- os: ubuntu-latest
python-version: "pypy3.9"
archs: "ppc64le"
- os: ubuntu-latest
python-version: "pypy3.10"
archs: "ppc64le"
- os: ubuntu-latest
python-version: "pypy3.9"
archs: "s390x"
- os: ubuntu-latest
python-version: "pypy3.10"
archs: "s390x"
- os: windows-latest
python-version: "3.7"
archs: "ARM64"
- os: windows-latest
python-version: "3.8"
archs: "ARM64"
- os: windows-latest
python-version: "pypy3.9"
archs: "ARM64"
- os: windows-latest
python-version: "pypy3.10"
archs: "ARM64"
- os: macos-latest
python-version: "3.7"
fail-fast: false
timeout-minutes: 60
timeout-minutes: 120
steps:
- name: Checkout
uses: actions/checkout@v4
Expand Down Expand Up @@ -139,17 +190,16 @@ jobs:
uses: pypa/cibuildwheel@v2.19
env:
CIBW_BUILD: ${{ env.CIBW_BUILD }}
CIBW_ARCHS_LINUX: auto aarch64 ppc64le s390x
CIBW_ARCHS_WINDOWS: auto ARM64
CIBW_ARCHS_MACOS: x86_64 arm64 universal2
CIBW_ARCHS: ${{ matrix.archs }}
CIBW_ARCHS_MACOS: ${{ matrix.archs }} universal2
with:
package-dir: .
output-dir: wheelhouse
config-file: "{package}/pyproject.toml"

- uses: actions/upload-artifact@v4
with:
name: wheels-${{ matrix.python-version }}-${{ matrix.os }}
name: wheels-${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.archs }}
path: wheelhouse/*.whl
if-no-files-found: error

Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add context manager to temporarily set the dictionary sorting mode by [@XuehaiPan](https://github.com/XuehaiPan) in [#147](https://github.com/metaopt/optree/pull/147).
- Add PyPy support by [@XuehaiPan](https://github.com/XuehaiPan) in [#145](https://github.com/metaopt/optree/pull/145).
- Add 32-bit wheels for Linux and Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#141](https://github.com/metaopt/optree/pull/141).
- Add Linux ppc64le and s390x wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/optree/pull/138).
Expand All @@ -21,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Use `stable` tag instead of 2.12.0 for `pybind11` version by [@XuehaiPan](https://github.com/XuehaiPan) in [#146](https://github.com/metaopt/optree/pull/146).
- Refactor the raw import statement in `setup.py` with `importlib` utilities by [@XuehaiPan](https://github.com/XuehaiPan) in [#135](https://github.com/metaopt/optree/pull/135).
- Update minimal version of `typing-extensions` to 4.5.0 for `typing_extensions.deprecated` by [@XuehaiPan](https://github.com/XuehaiPan) in [#134](https://github.com/metaopt/optree/pull/134).
- Update string representation for `OrderedDict` by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/optree/pull/133).
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ pytest: pytest-install
$(PYTHON) -m pytest --version
cd tests && $(PYTHON) -X dev -c 'import $(PROJECT_PATH)' && \
$(PYTHON) -X dev -c 'import $(PROJECT_PATH)._C; print(f"GLIBCXX_USE_CXX11_ABI={$(PROJECT_PATH)._C.GLIBCXX_USE_CXX11_ABI}")' && \
$(PYTHON) -X dev -m pytest --verbose --color=yes \
$(PYTHON) -X dev -m pytest --verbose --color=yes --durations=0 --showlocals \
--cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \
$(PYTESTOPTS) .

Expand Down Expand Up @@ -152,7 +152,7 @@ mypy: mypy-install

xdoctest: xdoctest-install
$(PYTHON) -m xdoctest --version
$(PYTHON) -m xdoctest $(PROJECT_PATH)
$(PYTHON) -m xdoctest --global-exec "from optree import *" $(PROJECT_PATH)

doctest: xdoctest

Expand Down
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Tree Manipulation Functions

.. autosummary::

dict_insertion_ordered
tree_flatten
tree_flatten_with_path
tree_flatten_with_accessor
Expand Down Expand Up @@ -55,6 +56,7 @@ Tree Manipulation Functions
tree_flatten_one_level
prefix_errors

.. autofunction:: dict_insertion_ordered
.. autofunction:: tree_flatten
.. autofunction:: tree_flatten_with_path
.. autofunction:: tree_flatten_with_accessor
Expand Down
39 changes: 34 additions & 5 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,25 @@ class PyTreeSpec {
const bool &none_is_leaf = false,
const std::string &registry_namespace = "");

// Check if should preserve the insertion order of the dictionary keys during flattening.
inline static bool IsDictInsertionOrdered(const std::string &registry_namespace,
const bool &inherit_global_namespace = true) {
return (sm_is_dict_insertion_ordered.find(registry_namespace) !=
sm_is_dict_insertion_ordered.end()) ||
(inherit_global_namespace &&
sm_is_dict_insertion_ordered.find("") != sm_is_dict_insertion_ordered.end());
}

// Set the namespace to preserve the insertion order of the dictionary keys during flattening.
inline static void SetDictInsertionOrdered(const bool &mode,
const std::string &registry_namespace) {
if (mode) [[likely]] {
sm_is_dict_insertion_ordered.insert(registry_namespace);
} else [[unlikely]] {
sm_is_dict_insertion_ordered.erase(registry_namespace);
}
}

private:
using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr;

Expand Down Expand Up @@ -266,7 +285,7 @@ class PyTreeSpec {
const bool &none_is_leaf,
const std::string &registry_namespace);

template <bool NoneIsLeaf, typename Span>
template <bool NoneIsLeaf, bool DictShouldBeSorted, typename Span>
bool FlattenIntoImpl(const py::handle &handle,
Span &leaves, // NOLINT[runtime/references]
const ssize_t &depth,
Expand All @@ -281,7 +300,11 @@ class PyTreeSpec {
const bool &none_is_leaf,
const std::string &registry_namespace);

template <bool NoneIsLeaf, typename LeafSpan, typename PathSpan, typename Stack>
template <bool NoneIsLeaf,
bool DictShouldBeSorted,
typename LeafSpan,
typename PathSpan,
typename Stack>
bool FlattenIntoWithPathImpl(const py::handle &handle,
LeafSpan &leaves, // NOLINT[runtime/references]
PathSpan &paths, // NOLINT[runtime/references]
Expand Down Expand Up @@ -329,6 +352,10 @@ class PyTreeSpec {
size_t operator()(const std::pair<const PyTreeSpec *, std::thread::id> &p) const;
};

// A set of namespaces that preserve the insertion order of the dictionary keys during
// flattening.
inline static std::unordered_set<std::string> sm_is_dict_insertion_ordered{};

// A set of (treespec, thread_id) pairs that are currently being represented as strings.
inline static std::unordered_set<std::pair<const PyTreeSpec *, std::thread::id>,
ThreadIndentTypeHash>
Expand All @@ -344,12 +371,13 @@ class PyTreeIter {
public:
PyTreeIter(const py::object &tree,
const std::optional<py::function> &leaf_predicate,
bool none_is_leaf,
std::string registry_namespace)
const bool &none_is_leaf,
const std::string &registry_namespace)
: m_agenda({std::make_pair(tree, 0)}),
m_leaf_predicate(leaf_predicate),
m_none_is_leaf(none_is_leaf),
m_namespace(std::move(registry_namespace)) {};
m_namespace(registry_namespace),
m_is_dict_insertion_ordered(PyTreeSpec::IsDictInsertionOrdered(registry_namespace)) {};

PyTreeIter() = delete;

Expand Down Expand Up @@ -377,6 +405,7 @@ class PyTreeIter {
std::optional<py::function> m_leaf_predicate;
bool m_none_is_leaf;
std::string m_namespace;
bool m_is_dict_insertion_ordered;

template <bool NoneIsLeaf>
[[nodiscard]] py::object NextImpl();
Expand Down
8 changes: 8 additions & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,11 @@ def unregister_node(
cls: type[CustomTreeNode[T]],
namespace: str = '',
) -> None: ...
def is_dict_insertion_ordered(
namespace: str = '',
inherit_global_namespace: bool = True,
) -> bool: ...
def set_dict_insertion_ordered(
mode: bool,
namespace: str = '',
) -> None: ...
2 changes: 2 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from optree.registry import (
AttributeKeyPathEntry,
GetitemKeyPathEntry,
dict_insertion_ordered,
register_keypaths,
register_pytree_node,
register_pytree_node_class,
Expand Down Expand Up @@ -200,6 +201,7 @@
'register_pytree_node',
'register_pytree_node_class',
'unregister_pytree_node',
'dict_insertion_ordered',
# Typing
'PyTreeSpec',
'PyTreeDef',
Expand Down
89 changes: 89 additions & 0 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import contextlib
import dataclasses
import functools
import inspect
Expand All @@ -29,6 +30,7 @@
Any,
Callable,
ClassVar,
Generator,
Iterable,
NamedTuple,
Sequence,
Expand Down Expand Up @@ -70,6 +72,7 @@
'register_pytree_node',
'register_pytree_node_class',
'unregister_pytree_node',
'dict_insertion_ordered',
]


Expand Down Expand Up @@ -491,6 +494,17 @@ def _dict_unflatten(keys: list[KT], values: Iterable[VT]) -> dict[KT, VT]:
return dict(safe_zip(keys, values))


def _dict_insertion_ordered_flatten(
dct: dict[KT, VT],
) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
keys, values = unzip2(dct.items())
return values, list(keys), keys


def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT]) -> dict[KT, VT]:
return dict(safe_zip(keys, values))


def _ordereddict_flatten(
dct: OrderedDict[KT, VT],
) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
Expand All @@ -517,6 +531,21 @@ def _defaultdict_unflatten(
return defaultdict(default_factory, _dict_unflatten(keys, values))


def _defaultdict_insertion_ordered_flatten(
dct: defaultdict[KT, VT],
) -> tuple[tuple[VT, ...], tuple[Callable[[], VT] | None, list[KT]], tuple[KT, ...]]:
values, keys, entries = _dict_insertion_ordered_flatten(dct)
return values, (dct.default_factory, keys), entries


def _defaultdict_insertion_ordered_unflatten(
metadata: tuple[Callable[[], VT], list[KT]],
values: Iterable[VT],
) -> defaultdict[KT, VT]:
default_factory, keys = metadata
return defaultdict(default_factory, _dict_insertion_ordered_unflatten(keys, values))


def _deque_flatten(deq: deque[T]) -> tuple[deque[T], int | None]:
return deq, deq.maxlen

Expand Down Expand Up @@ -566,6 +595,23 @@ def _pytree_node_registry_get(
handler = _NODETYPE_REGISTRY.get((namespace, cls))
if handler is not None:
return handler

if _C.is_dict_insertion_ordered(namespace):
if cls is dict:
return PyTreeNodeRegistryEntry(
dict,
_dict_insertion_ordered_flatten, # type: ignore[arg-type]
_dict_insertion_ordered_unflatten, # type: ignore[arg-type]
path_entry_type=MappingEntry,
)
if cls is defaultdict:
return PyTreeNodeRegistryEntry(
defaultdict,
_defaultdict_insertion_ordered_flatten, # type: ignore[arg-type]
_defaultdict_insertion_ordered_unflatten, # type: ignore[arg-type]
path_entry_type=MappingEntry,
)

handler = _NODETYPE_REGISTRY.get(cls)
if handler is not None:
return handler
Expand All @@ -580,6 +626,49 @@ def _pytree_node_registry_get(
del _pytree_node_registry_get


@contextlib.contextmanager
def dict_insertion_ordered(mode: bool, *, namespace: str) -> Generator[None, None, None]:
"""Context manager to temporarily set the dictionary sorting mode.
This context manager is used to temporarily set the dictionary sorting mode for a specific
namespace. The dictionary sorting mode is used to determine whether the keys of a dictionary
should be sorted or keeping the insertion order when flattening a pytree.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_flatten(tree) # doctest: +IGNORE_WHITESPACE
(
[1, 2, 3, 4, 5],
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> with dict_insertion_ordered(True, namespace='some-namespace'): # doctest: +IGNORE_WHITESPACE
... tree_flatten(tree, namespace='some-namespace')
(
[2, 3, 4, 1, 5],
PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace')
)
Args:
mode (bool): The dictionary sorting mode to set.
namespace (str): The namespace to set the dictionary sorting mode for.
"""
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
if namespace == '':
raise ValueError('The namespace cannot be an empty string.')
if namespace is __GLOBAL_NAMESPACE:
namespace = ''

with __REGISTRY_LOCK:
prev = _C.is_dict_insertion_ordered(namespace, inherit_global_namespace=False)
_C.set_dict_insertion_ordered(bool(mode), namespace)

try:
yield
finally:
with __REGISTRY_LOCK:
_C.set_dict_insertion_ordered(prev, namespace)


####################################################################################################

warnings.filterwarnings('ignore', category=FutureWarning, module=__name__, append=True)
Expand Down
Loading

0 comments on commit 9606d15

Please sign in to comment.