Skip to content

Commit

Permalink
chore(treespec): update string representation for OrderedDict (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Apr 4, 2024
1 parent ad0279b commit 516f99e
Show file tree
Hide file tree
Showing 11 changed files with 171 additions and 107 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ jobs:
submodules: "recursive"
fetch-depth: 1

- name: Set up Python 3.10
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.12"
update-environment: true

- name: Upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Update string representation for `OrderedDict` by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/optree/pull/133).

### Fixed

Expand Down
38 changes: 19 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -433,39 +433,39 @@ While flattening a tree, it will remain in the tree structure definitions rather
```

OpTree provides a keyword argument `none_is_leaf` to determine whether to consider the `None` object as a leaf, like other opaque objects.
If `none_is_leaf=True`, the `None` object will place in the leaves list.
If `none_is_leaf=True`, the `None` object will be placed in the leaves list.
Otherwise, the `None` object will remain in the tree specification (structure).

```python
>>> import torch

>>> linear = torch.nn.Linear(in_features=3, out_features=2, bias=False)
>>> linear._parameters # a container has None
OrderedDict([
('weight', Parameter containing:
tensor([[-0.6677, 0.5209, 0.3295],
[-0.4876, -0.3142, 0.1785]], requires_grad=True)),
('bias', None)
])
OrderedDict({
'weight': Parameter containing:
tensor([[-0.6677, 0.5209, 0.3295],
[-0.4876, -0.3142, 0.1785]], requires_grad=True),
'bias': None
})

>>> optree.tree_map(torch.zeros_like, linear._parameters)
OrderedDict([
('weight', tensor([[0., 0., 0.],
[0., 0., 0.]])),
('bias', None)
])
OrderedDict({
'weight': tensor([[0., 0., 0.],
[0., 0., 0.]]),
'bias': None
})

>>> optree.tree_map(torch.zeros_like, linear._parameters, none_is_leaf=True)
Traceback (most recent call last):
...
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not NoneType

>>> optree.tree_map(lambda t: torch.zeros_like(t) if t is not None else 0, linear._parameters, none_is_leaf=True)
OrderedDict([
('weight', tensor([[0., 0., 0.],
[0., 0., 0.]])),
('bias', 0)
])
OrderedDict({
'weight': tensor([[0., 0., 0.],
[0., 0., 0.]]),
'bias': 0
})
```

### Key Ordering for Dictionaries
Expand All @@ -489,9 +489,9 @@ If users want to keep the values in the insertion order in pytree traversal, the
>>> OrderedDict([('a', [1, 2]), ('b', [3])]) == OrderedDict([('b', [3]), ('a', [1, 2])])
False
>>> optree.tree_flatten(OrderedDict([('a', [1, 2]), ('b', [3])]))
([1, 2, 3], PyTreeSpec(OrderedDict([('a', [*, *]), ('b', [*])])))
([1, 2, 3], PyTreeSpec(OrderedDict({'a': [*, *], 'b': [*]})))
>>> optree.tree_flatten(OrderedDict([('b', [3]), ('a', [1, 2])]))
([3, 1, 2], PyTreeSpec(OrderedDict([('b', [*]), ('a', [*, *])])))
([3, 1, 2], PyTreeSpec(OrderedDict({'b': [*], 'a': [*, *]})))
```

**Since OpTree v0.9.0, the key order of the reconstructed output dictionaries from `tree_unflatten` is guaranteed to be consistent with the key order of the input dictionaries in `tree_flatten`.**
Expand Down
36 changes: 24 additions & 12 deletions optree/integration/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,21 +127,33 @@ def tree_ravel(
... 'bias': jnp.arange(10, 11, dtype=jnp.float32).reshape((1,))
... },
... }
>>> tree
{'layer1': {'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([6., 7.], dtype=float32)},
'layer2': {'weight': Array([[8., 9.]], dtype=float32),
'bias': Array([10.], dtype=float32)}}
>>> tree # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([6., 7.], dtype=float32)
},
'layer2': {
'weight': Array([[8., 9.]], dtype=float32),
'bias': Array([10.], dtype=float32)
}
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
Array([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=float32)
>>> unravel_func(flat)
{'layer1': {'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([6., 7.], dtype=float32)},
'layer2': {'weight': Array([[8., 9.]], dtype=float32),
'bias': Array([10.], dtype=float32)}}
>>> unravel_func(flat) # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([6., 7.], dtype=float32)
},
'layer2': {
'weight': Array([[8., 9.]], dtype=float32),
'bias': Array([10.], dtype=float32)
}
}
Args:
tree (pytree): a pytree of arrays and scalars to ravel.
Expand Down
36 changes: 24 additions & 12 deletions optree/integration/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,33 @@ def tree_ravel(
... 'bias': np.arange(10, 11, dtype=np.float32).reshape((1,))
... },
... }
>>> tree
{'layer1': {'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([6., 7.], dtype=float32)},
'layer2': {'weight': array([[8., 9.]], dtype=float32),
'bias': array([10.], dtype=float32)}}
>>> tree # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([6., 7.], dtype=float32)
},
'layer2': {
'weight': array([[8., 9.]], dtype=float32),
'bias': array([10.], dtype=float32)
}
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
array([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=float32)
>>> unravel_func(flat)
{'layer1': {'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([6., 7.], dtype=float32)},
'layer2': {'weight': array([[8., 9.]], dtype=float32),
'bias': array([10.], dtype=float32)}}
>>> unravel_func(flat) # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([6., 7.], dtype=float32)
},
'layer2': {
'weight': array([[8., 9.]], dtype=float32),
'bias': array([10.], dtype=float32)
}
}
Args:
tree (pytree): a pytree of arrays and scalars to ravel.
Expand Down
36 changes: 24 additions & 12 deletions optree/integration/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,33 @@ def tree_ravel(
... 'bias': torch.arange(10, 11, dtype=torch.float64).reshape((1,))
... },
... }
>>> tree
{'layer1': {'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([6., 7.], dtype=torch.float64)},
'layer2': {'weight': tensor([[8., 9.]], dtype=torch.float64),
'bias': tensor([10.], dtype=torch.float64)}}
>>> tree # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([6., 7.], dtype=torch.float64)
},
'layer2': {
'weight': tensor([[8., 9.]], dtype=torch.float64),
'bias': tensor([10.], dtype=torch.float64)
}
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
tensor([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=torch.float64)
>>> unravel_func(flat)
{'layer1': {'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([6., 7.], dtype=torch.float64)},
'layer2': {'weight': tensor([[8., 9.]], dtype=torch.float64),
'bias': tensor([10.], dtype=torch.float64)}}
>>> unravel_func(flat) # doctest: +IGNORE_WHITESPACE
{
'layer1': {
'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([6., 7.], dtype=torch.float64)
},
'layer2': {
'weight': tensor([[8., 9.]], dtype=torch.float64),
'bias': tensor([10.], dtype=torch.float64)
}
}
Args:
tree (pytree): a pytree of tensors to ravel.
Expand Down
32 changes: 16 additions & 16 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,12 @@ def tree_flatten(
>>> tree_flatten(tree) # doctest: +IGNORE_WHITESPACE
(
[2, 3, 4, 1, 5],
PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)]))
PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}))
)
>>> tree_flatten(tree, none_is_leaf=True) # doctest: +IGNORE_WHITESPACE
(
[2, 3, 4, 1, None, 5],
PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf)
PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)
)
Args:
Expand Down Expand Up @@ -233,13 +233,13 @@ def tree_flatten_with_path(
(
[('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('d',)],
[2, 3, 4, 1, 5],
PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)]))
PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}))
)
>>> tree_flatten_with_path(tree, none_is_leaf=True) # doctest: +IGNORE_WHITESPACE
(
[('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('c',), ('d',)],
[2, 3, 4, 1, None, 5],
PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf)
PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)
)
Args:
Expand Down Expand Up @@ -951,7 +951,7 @@ def tree_transpose_map_with_path(
>>> tree_transpose_map_with_path( # doctest: +IGNORE_WHITESPACE
... lambda p, x: {'path': p, 'value': x},
... tree,
... inner_treespec=tree_structure({'path': 0, 'value': 0})),
... inner_treespec=tree_structure({'path': 0, 'value': 0}),
... )
{
'path': {'b': (('b', 0), [('b', 1, 0), ('b', 1, 1)]), 'a': ('a',), 'c': (('c', 0), ('c', 1))},
Expand Down Expand Up @@ -1694,7 +1694,7 @@ def tree_max(
>>> tree_max({})
Traceback (most recent call last):
...
ValueError: max() arg is an empty sequence
ValueError: max() iterable argument is empty
>>> tree_max({}, default=0)
0
>>> tree_max({'x': 0, 'y': (2, 1)})
Expand All @@ -1704,15 +1704,15 @@ def tree_max(
>>> tree_max({'a': None}) # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
...
ValueError: max() arg is an empty sequence
ValueError: max() iterable argument is empty
>>> tree_max({'a': None}, default=0) # `None` is a non-leaf node with arity 0 by default
0
>>> tree_max({'a': None}, none_is_leaf=True)
None
>>> tree_max(None) # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
...
ValueError: max() arg is an empty sequence
ValueError: max() iterable argument is empty
>>> tree_max(None, default=0)
0
>>> tree_max(None, none_is_leaf=True)
Expand Down Expand Up @@ -1789,7 +1789,7 @@ def tree_min(
>>> tree_min({})
Traceback (most recent call last):
...
ValueError: min() arg is an empty sequence
ValueError: min() iterable argument is empty
>>> tree_min({}, default=0)
0
>>> tree_min({'x': 0, 'y': (2, 1)})
Expand All @@ -1799,15 +1799,15 @@ def tree_min(
>>> tree_min({'a': None}) # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
...
ValueError: min() arg is an empty sequence
ValueError: min() iterable argument is empty
>>> tree_min({'a': None}, default=0) # `None` is a non-leaf node with arity 0 by default
0
>>> tree_min({'a': None}, none_is_leaf=True)
None
>>> tree_min(None) # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
...
ValueError: min() arg is an empty sequence
ValueError: min() iterable argument is empty
>>> tree_min(None, default=0)
0
>>> tree_min(None, none_is_leaf=True)
Expand Down Expand Up @@ -2455,15 +2455,15 @@ def treespec_ordereddict(
See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.
>>> treespec_ordereddict({'a': treespec_leaf(), 'b': treespec_leaf()})
PyTreeSpec(OrderedDict([('a', *), ('b', *)]))
PyTreeSpec(OrderedDict({'a': *, 'b': *}))
>>> treespec_ordereddict([('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())])
PyTreeSpec(OrderedDict([('b', *), ('c', *), ('a', None)]))
PyTreeSpec(OrderedDict({'b': *, 'c': *, 'a': None}))
>>> treespec_ordereddict()
PyTreeSpec(OrderedDict([]))
PyTreeSpec(OrderedDict())
>>> treespec_ordereddict(a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()]))
PyTreeSpec(OrderedDict([('a', *), ('b', (*, *))]))
PyTreeSpec(OrderedDict({'a': *, 'b': (*, *)}))
>>> treespec_ordereddict({'a': treespec_leaf(), 'b': tree_structure([1, 2])})
PyTreeSpec(OrderedDict([('a', *), ('b', [*, *])]))
PyTreeSpec(OrderedDict({'a': *, 'b': [*, *]}))
>>> treespec_ordereddict({'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)})
Traceback (most recent call last):
...
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ docs = [
]
benchmark = [
"jax[cpu] >= 0.4.6, < 0.5.0a0",
"torch >= 2.0, < 2.1.0a0",
"torch >= 2.0, < 2.3.0a0",
"torchvision",
"dm-tree >= 0.1, < 0.2.0a0",
"pandas",
Expand Down
Loading

0 comments on commit 516f99e

Please sign in to comment.