Skip to content

Commit

Permalink
feat: allow tree-map with mixed inputs of ordered and unordered dicti…
Browse files Browse the repository at this point in the history
…onaries (#42)
  • Loading branch information
XuehaiPan authored Mar 11, 2023
1 parent f608bfb commit fc2b0a1
Show file tree
Hide file tree
Showing 11 changed files with 242 additions and 124 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Allow tree-map with mixed inputs of ordered and unordered dictionaries by [@XuehaiPan](https://github.com/XuehaiPan) in [#42](https://github.com/metaopt/optree/pull/42).
- Add more utility functions for `namedtuple` and `PyStructSequence` type by [@XuehaiPan](https://github.com/XuehaiPan) in [#41](https://github.com/metaopt/optree/pull/41).
- Add methods `PyTreeSpec.is_prefix` and `PyTreeSpec.is_suffix` and function `tree_broadcast_prefix` by [@XuehaiPan](https://github.com/XuehaiPan) in [#40](https://github.com/metaopt/optree/pull/40).
- Add tree reduce functions `tree_sum`, `tree_max`, and `tree_min` by [@XuehaiPan](https://github.com/XuehaiPan) in [#39](https://github.com/metaopt/optree/pull/39).
Expand All @@ -22,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Allow tree-map with mixed inputs of ordered and unordered dictionaries by [@XuehaiPan](https://github.com/XuehaiPan) in [#42](https://github.com/metaopt/optree/pull/42).
- Use more appropriate exception handling (e.g., change `ValueError` to `TypeError` in `structseq_fields`) by [@XuehaiPan](https://github.com/XuehaiPan) in [#41](https://github.com/metaopt/optree/pull/41).
- Inherit `optree._C.InternalError` from `SystemError` rather than `RuntimeError` by [@XuehaiPan](https://github.com/XuehaiPan) in [#41](https://github.com/metaopt/optree/pull/41).
- Change keyword argument `initial` to `initializer` for `tree_reduce` to align with `functools.reduce` by [@XuehaiPan](https://github.com/XuehaiPan) in [#39](https://github.com/metaopt/optree/pull/39).
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ project(optree LANGUAGES CXX)

include(FetchContent)
set(PYBIND11_VERSION v2.10.3)
set(ABSEIL_CPP_VERSION 20220623.1)
set(ABSEIL_CPP_VERSION 20230125.1)
set(THIRD_PARTY_DIR "${CMAKE_SOURCE_DIR}/third-party")

if(NOT CMAKE_BUILD_TYPE)
Expand Down
3 changes: 3 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ fmt
CustomTreeNode
ForwardRef
Deque
deque
maxlen
TensorTree
unhashable
xys
Expand All @@ -62,3 +64,4 @@ isinstance
initializer
CPython
CPython's
sortable
11 changes: 11 additions & 0 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,17 @@ inline void AssertExactDefaultDict(const py::handle& object) {
}
}

inline void AssertExactStandardDict(const py::handle& object) {
if (!(PyDict_CheckExact(object.ptr()) || object.get_type().is(PyOrderedDictTypeObject) ||
object.get_type().is(PyDefaultDictTypeObject))) [[unlikely]] {
throw py::value_error(
absl::StrFormat("Expected an instance of "
"dict, collections.OrderedDict, or collections.defaultdict, "
"got %s.",
py::repr(object)));
}
}

inline void AssertExactDeque(const py::handle& object) {
if (!object.get_type().is(PyDequeTypeObject)) [[unlikely]] {
throw py::value_error(absl::StrFormat("Expected an instance of collections.deque, got %s.",
Expand Down
71 changes: 65 additions & 6 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import difflib
import functools
import textwrap
from collections import deque
from collections import OrderedDict, defaultdict, deque
from typing import Any, Callable, cast, overload

from optree import _C
Expand Down Expand Up @@ -1538,6 +1538,9 @@ def prefix_errors(
)


STANDARD_DICT_TYPES = frozenset([dict, OrderedDict, defaultdict])


# pylint: disable-next=too-many-locals
def _prefix_error(
key_path: KeyPath,
Expand All @@ -1555,7 +1558,17 @@ def _prefix_error(
return

# The subtrees may disagree because their roots are of different types:
if type(prefix_tree) is not type(full_tree):
prefix_tree_type = type(prefix_tree)
full_tree_type = type(full_tree)
both_standard_dict = (
prefix_tree_type in STANDARD_DICT_TYPES # type: ignore[comparison-overlap]
and full_tree_type in STANDARD_DICT_TYPES # type: ignore[comparison-overlap]
)
both_deque = prefix_tree_type is deque and full_tree_type is deque # type: ignore[comparison-overlap]
if prefix_tree_type is not full_tree_type and (
# Special handling for directory types
not both_standard_dict
):
yield lambda name: ValueError(
f'pytree structure error: different types at key path\n'
f' {{name}}{key_path.pprint()}\n'
Expand All @@ -1575,20 +1588,63 @@ def _prefix_error(
full_tree_children, full_tree_metadata, _ = flatten_one_level(
full_tree, none_is_leaf=none_is_leaf, namespace=namespace
)
# Special handling for directory types
if both_standard_dict:
prefix_tree_keys = (
prefix_tree_metadata
if prefix_tree_type is not defaultdict # type: ignore[comparison-overlap]
else prefix_tree_metadata[1] # type: ignore[index]
)
full_tree_keys = (
full_tree_metadata
if full_tree_type is not defaultdict # type: ignore[comparison-overlap]
else full_tree_metadata[1] # type: ignore[index]
)
prefix_tree_keys_set = set(prefix_tree_keys)
full_tree_keys_set = set(full_tree_keys)
if prefix_tree_keys_set != full_tree_keys_set:
missing_keys = sorted(prefix_tree_keys_set.difference(full_tree_keys_set))
extra_keys = sorted(full_tree_keys_set.difference(prefix_tree_keys_set))
key_difference = ''
if missing_keys:
key_difference += f'\nmissing key(s):\n {missing_keys}'
if extra_keys:
key_difference += f'\nextra key(s):\n {extra_keys}'
yield lambda name: ValueError(
f'pytree structure error: different pytree keys at key path\n'
f' {{name}}{key_path.pprint()}\n'
f'At that key path, the prefix pytree {{name}} has a subtree of type\n'
f' {prefix_tree_type}\n'
f'with {len(prefix_tree_keys)} key(s)\n'
f' {prefix_tree_keys}\n'
f'but at the same key path the full pytree has a subtree of type\n'
f' {full_tree_type}\n'
f'but with {len(full_tree_keys)} key(s)\n'
f' {full_tree_keys}{key_difference}'.format(name=name)
)
return # don't look for more errors in this subtree

if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
f'pytree structure error: different numbers of pytree children at key path\n'
f' {{name}}{key_path.pprint()}\n'
f'At that key path, the prefix pytree {{name}} has a subtree of type\n'
f' {type(prefix_tree)}\n'
f' {prefix_tree_type}\n'
f'with {len(prefix_tree_children)} children, '
f'but at the same key path the full pytree has a subtree of the same '
f'type but with {len(full_tree_children)} children.'.format(name=name)
)
return # don't look for more errors in this subtree

# Or they may disagree if their roots have different pytree metadata:
if prefix_tree_metadata != full_tree_metadata:
if (
prefix_tree_metadata != full_tree_metadata
and (not both_deque) # ignore maxlen mismatch for deque
and (
# Special handling for directory types already done in the keys check above
not both_standard_dict
)
):
prefix_tree_metadata_repr = repr(prefix_tree_metadata)
full_tree_metadata_repr = repr(full_tree_metadata)
metadata_diff = textwrap.indent(
Expand All @@ -1604,7 +1660,7 @@ def _prefix_error(
f'pytree structure error: different pytree metadata at key path\n'
f' {{name}}{key_path.pprint()}\n'
f'At that key path, the prefix pytree {{name}} has a subtree of type\n'
f' {type(prefix_tree)}\n'
f' {prefix_tree_type}\n'
f'with metadata\n'
f' {prefix_tree_metadata_repr}\n'
f'but at the same key path the full pytree has a subtree of the same '
Expand All @@ -1619,7 +1675,10 @@ def _prefix_error(
# so recurse:
keys = _child_keys(prefix_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
keys_ = _child_keys(full_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) # type: ignore[arg-type]
assert keys == keys_, f'equal pytree nodes gave different keys: {keys} and {keys_}'
assert keys == keys_ or (
# Special handling for directory types already done in the keys check above
both_standard_dict
), f'equal pytree nodes gave different keys: {keys} and {keys_}'
# pylint: disable-next=invalid-name
for k, t1, t2 in zip(keys, prefix_tree_children, full_tree_children):
yield from _prefix_error(
Expand Down
48 changes: 12 additions & 36 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from optree import _C
from optree.typing import KT, VT, CustomTreeNode, FlattenFunc, PyTree, T, UnflattenFunc
from optree.utils import safe_zip, unzip2
from optree.utils import safe_zip, total_order_sorted, unzip2


__all__ = [
Expand Down Expand Up @@ -277,38 +277,12 @@ def tree_unflatten(cls, metadata, children):
return cls


def _sorted_items(items: Iterable[tuple[KT, VT]]) -> Iterable[tuple[KT, VT]]: # pragma: no cover
try:
# Sort directly if possible (do not use `key` for performance reasons)
return sorted(items)
except TypeError: # the keys are not comparable
try:
# Add `{obj.__class__.__module__}.{obj.__class__.__qualname__}` to the key order to make
# it sortable between different types (e.g. `int` vs. `str`)
return sorted(
items,
# pylint: disable-next=consider-using-f-string
key=lambda kv: ('{0.__module__}.{0.__qualname__}'.format(kv[0].__class__), kv),
)
except TypeError: # cannot sort the keys (e.g. user-defined types)
return items # fallback to insertion order


def _sorted_keys(dct: dict[KT, VT]) -> Iterable[KT]: # pragma: no cover
try:
# Sort directly if possible (do not use `key` for performance reasons)
return sorted(dct) # type: ignore[type-var]
except TypeError: # the keys are not comparable
try:
# Add `{obj.__class__.__module__}.{obj.__class__.__qualname__}` to the key order to make
# it sortable between different types (e.g. `int` vs. `str`)
return sorted(
dct,
# pylint: disable-next=consider-using-f-string
key=lambda o: ('{0.__module__}.{0.__qualname__}'.format(o.__class__), o),
)
except TypeError: # cannot sort the keys (e.g. user-defined types)
return dct # fallback to insertion order
def _sorted_items(items: Iterable[tuple[KT, VT]]) -> list[tuple[KT, VT]]: # pragma: no cover
return total_order_sorted(items, key=lambda kv: kv[0])


def _sorted_keys(dct: dict[KT, VT]) -> list[KT]: # pragma: no cover
return total_order_sorted(dct)


def _dict_flatten(dct: dict[KT, VT]) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
Expand Down Expand Up @@ -365,12 +339,14 @@ def _defaultdict_flatten(


def _pytree_node_registry_get(
type: type, *, namespace: str = __GLOBAL_NAMESPACE
cls: type,
*,
namespace: str = __GLOBAL_NAMESPACE,
) -> PyTreeNodeRegistryEntry | None:
entry: PyTreeNodeRegistryEntry | None = _nodetype_registry.get(type)
entry: PyTreeNodeRegistryEntry | None = _nodetype_registry.get(cls)
if entry is not None or namespace is __GLOBAL_NAMESPACE or namespace == '':
return entry
return _nodetype_registry.get((namespace, type))
return _nodetype_registry.get((namespace, cls))


register_pytree_node.get = _pytree_node_registry_get # type: ignore[attr-defined]
Expand Down
46 changes: 43 additions & 3 deletions optree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,50 @@

from __future__ import annotations

from typing import Any, Iterable, Sequence
from typing import Any, Callable, Iterable, Sequence

from optree.typing import T, U

def safe_zip(*args: Sequence[Any]) -> list[tuple[Any, ...]]:

def total_order_sorted(
iterable: Iterable[T],
*,
key: Callable[[T], Any] | None = None,
reverse: bool = False,
) -> list[T]:
"""Sort an iterable in a total order.
This is useful for sorting objects that are not comparable, e.g., dictionaries with different
types of keys.
"""
sequence = list(iterable)

try:
# Sort directly if possible
return sorted(sequence, key=key, reverse=reverse) # type: ignore[type-var,arg-type]
except TypeError:
if key is None:

def key_fn(x: T) -> tuple[str, Any]:
# pylint: disable-next=consider-using-f-string
return ('{0.__module__}.{0.__qualname__}'.format(x.__class__), x)

else:

def key_fn(x: T) -> tuple[str, Any]:
y = key(x) # type: ignore[misc]
# pylint: disable-next=consider-using-f-string
return ('{0.__module__}.{0.__qualname__}'.format(y.__class__), y)

try:
# Add `{obj.__class__.__module__}.{obj.__class__.__qualname__}` to the key order to make
# it sortable between different types (e.g., `int` vs. `str`)
return sorted(sequence, key=key_fn, reverse=reverse)
except TypeError: # cannot sort the keys (e.g., user-defined types)
return sequence # fallback to original order


def safe_zip(*args: Sequence[T]) -> list[tuple[T, ...]]:
"""Strict zip that requires all arguments to be the same length."""
n = len(args[0])
for arg in args[1:]:
Expand All @@ -28,7 +68,7 @@ def safe_zip(*args: Sequence[Any]) -> list[tuple[Any, ...]]:
return list(zip(*args))


def unzip2(xys: Iterable[tuple[Any, Any]]) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
def unzip2(xys: Iterable[tuple[T, U]]) -> tuple[tuple[T, ...], tuple[U, ...]]:
"""Unzip sequence of length-2 tuples into two tuples."""
# Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
# is too permissive about inputs, and does not guarantee a length-2 output.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ select = [
"ISC", # flake8-implicit-str-concat
"PIE", # flake8-pie
"PYI", # flake8-pyi
"Q", # flake8-quotes
"RSE", # flake8-raise
"RET", # flake8-return
"SIM", # flake8-simplify
Expand Down
Loading

0 comments on commit fc2b0a1

Please sign in to comment.