Skip to content

Commit

Permalink
test: update tests for braodcast
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Oct 6, 2023
1 parent a23aa76 commit 4180be3
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 24 deletions.
204 changes: 182 additions & 22 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,16 @@ def add_leaves(p, x):
assert leaves == [1, 2, 3, None, 4]


def test_tree_replace_nones():
sentinel = object()
assert optree.tree_replace_nones(sentinel, {'a': 1, 'b': None, 'c': (2, None)}) == {
'a': 1,
'b': sentinel,
'c': (2, sentinel),
}
assert optree.tree_replace_nones(sentinel, None) == sentinel


@parametrize(
tree=TREES,
none_is_leaf=[False, True],
Expand Down Expand Up @@ -931,36 +941,39 @@ class MyExtraDict(MyAnotherDict):


def test_tree_broadcast_prefix():
assert optree.tree_broadcast_prefix(1, [1, 2, 3]) == [1, 1, 1]
assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, 3]) == [1, 2, 3]
assert optree.tree_broadcast_prefix(1, [2, 3, 4]) == [1, 1, 1]
assert optree.tree_broadcast_prefix([1, 2, 3], [4, 5, 6]) == [1, 2, 3]
with pytest.raises(
ValueError,
match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].'),
match=re.escape('list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].'),
):
optree.tree_broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) == [1, 2, (3, 3)]
assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}]) == [
optree.tree_broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
assert optree.tree_broadcast_prefix([1, 2, 3], [4, 5, (6, 7)]) == [1, 2, (3, 3)]
assert optree.tree_broadcast_prefix(
[1, 2, 3],
[4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}],
) == [
1,
2,
{'a': 3, 'b': 3, 'c': (None, 3)},
]
assert optree.tree_broadcast_prefix(
[1, 2, 3],
[1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}],
[4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}],
none_is_leaf=True,
) == [1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}]


def test_broadcast_prefix():
assert optree.broadcast_prefix(1, [1, 2, 3]) == [1, 1, 1]
assert optree.broadcast_prefix([1, 2, 3], [1, 2, 3]) == [1, 2, 3]
assert optree.broadcast_prefix(1, [2, 3, 4]) == [1, 1, 1]
assert optree.broadcast_prefix([1, 2, 3], [4, 5, 6]) == [1, 2, 3]
with pytest.raises(
ValueError,
match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].'),
match=re.escape('list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].'),
):
optree.broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
assert optree.broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) == [1, 2, 3, 3]
assert optree.broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}]) == [
optree.broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
assert optree.broadcast_prefix([1, 2, 3], [4, 5, (6, 7)]) == [1, 2, 3, 3]
assert optree.broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}]) == [
1,
2,
3,
Expand All @@ -969,19 +982,166 @@ def test_broadcast_prefix():
]
assert optree.broadcast_prefix(
[1, 2, 3],
[1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}],
[4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}],
none_is_leaf=True,
) == [1, 2, 3, 3, 3, 3]


def test_tree_replace_nones():
sentinel = object()
assert optree.tree_replace_nones(sentinel, {'a': 1, 'b': None, 'c': (2, None)}) == {
'a': 1,
'b': sentinel,
'c': (2, sentinel),
}
assert optree.tree_replace_nones(sentinel, None) == sentinel
def test_tree_broadcast_common():
assert optree.tree_broadcast_common(1, [2, 3, 4]) == ([1, 1, 1], [2, 3, 4])
assert optree.tree_broadcast_common([1, 2, 3], [4, 5, 6]) == ([1, 2, 3], [4, 5, 6])
with pytest.raises(
ValueError,
match=re.escape('list arity mismatch; expected: 3, got: 4.'),
):
optree.tree_broadcast_common([1, 2, 3], [1, 2, 3, 4])
assert optree.tree_broadcast_common([1, 2, 3], [4, 5, (6, 7)]) == (
[1, 2, (3, 3)],
[4, 5, (6, 7)],
)
assert optree.tree_broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)]) == (
[1, (2, 3), (4, 4)],
[5, (6, 6), (7, 8)],
)
assert optree.tree_broadcast_common(
[1, (2, 3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
) == (
[1, (2, 3), {'a': 4, 'b': 4, 'c': (None, 4)}],
[5, (6, 6), {'a': 7, 'b': 8, 'c': (None, 9)}],
)
assert optree.tree_broadcast_common(
[1, OrderedDict(foo=2, bar=3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
) == (
[1, OrderedDict(foo=2, bar=3), {'a': 4, 'b': 4, 'c': (None, 4)}],
[5, OrderedDict(foo=6, bar=6), {'a': 7, 'b': 8, 'c': (None, 9)}],
)
assert optree.tree_broadcast_common(
[1, OrderedDict(foo=2, bar=3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
none_is_leaf=True,
) == (
[1, OrderedDict(foo=2, bar=3), {'a': 4, 'b': 4, 'c': (4, 4)}],
[5, OrderedDict(foo=6, bar=6), {'a': 7, 'b': 8, 'c': (None, 9)}],
)
assert optree.tree_broadcast_common(
[1, OrderedDict(foo=2, bar=3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
none_is_leaf=True,
) == (
[1, OrderedDict(foo=2, bar=3), {'a': 4, 'b': 4, 'c': (4, 4)}],
[5, OrderedDict(foo=6, bar=6), {'a': 7, 'b': 8, 'c': (None, 9)}],
)
assert optree.tree_broadcast_common(
[1, OrderedDict(b=4, c=5, a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}],
) == (
[(1, 1), OrderedDict(b=4, c=(None, 5), a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': (8, 8), 'b': 9}],
)
assert optree.tree_broadcast_common(
[1, OrderedDict(b=4, c=5, a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}],
none_is_leaf=True,
) == (
[(1, 1), OrderedDict(b=4, c=(5, 5), a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': (8, 8), 'b': 9}],
)
assert optree.tree_broadcast_common(
[1, {'c': (None, 4), 'b': 3, 'a': 2}],
[(5, 6), OrderedDict(b=9, c=0, a=(7, 8))],
) == (
[(1, 1), {'c': (None, 4), 'b': 3, 'a': (2, 2)}],
[(5, 6), OrderedDict(b=9, c=(None, 0), a=(7, 8))],
)
assert optree.tree_broadcast_common(
[1, {'b': 3, 'a': 2, 'c': (None, 4)}],
[(5, 6), OrderedDict(b=9, c=0, a=(7, 8))],
none_is_leaf=True,
) == (
[(1, 1), {'c': (None, 4), 'b': 3, 'a': (2, 2)}],
[(5, 6), OrderedDict(b=9, c=(0, 0), a=(7, 8))],
)


def test_broadcast_common():
assert optree.broadcast_common(1, [2, 3, 4]) == ([1, 1, 1], [2, 3, 4])
assert optree.broadcast_common([1, 2, 3], [4, 5, 6]) == ([1, 2, 3], [4, 5, 6])
with pytest.raises(
ValueError,
match=re.escape('list arity mismatch; expected: 3, got: 4.'),
):
optree.broadcast_common([1, 2, 3], [1, 2, 3, 4])
assert optree.broadcast_common([1, 2, 3], [4, 5, (6, 7)]) == (
[1, 2, 3, 3],
[4, 5, 6, 7],
)
assert optree.broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)]) == (
[1, 2, 3, 4, 4],
[5, 6, 6, 7, 8],
)
assert optree.broadcast_common(
[1, (2, 3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
) == (
[1, 2, 3, 4, 4, 4],
[5, 6, 6, 7, 8, 9],
)
assert optree.broadcast_common(
[1, (2, 3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
none_is_leaf=True,
) == (
[1, 2, 3, 4, 4, 4, 4],
[5, 6, 6, 7, 8, None, 9],
)
assert optree.broadcast_common(
[1, OrderedDict(foo=2, bar=3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
) == (
[1, 2, 3, 4, 4, 4],
[5, 6, 6, 7, 8, 9],
)
assert optree.broadcast_common(
[1, OrderedDict(foo=2, bar=3), 4],
[5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}],
none_is_leaf=True,
) == (
[1, 2, 3, 4, 4, 4, 4],
[5, 6, 6, 7, 8, None, 9],
)

assert optree.broadcast_common(
[1, OrderedDict(b=4, c=5, a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}],
) == (
[1, 1, 4, 5, 2, 3],
[6, 7, 9, 0, 8, 8],
)
assert optree.broadcast_common(
[1, OrderedDict(b=4, c=5, a=(2, 3))],
[(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}],
none_is_leaf=True,
) == (
[1, 1, 4, 5, 5, 2, 3],
[6, 7, 9, None, 0, 8, 8],
)
assert optree.broadcast_common(
[1, {'c': (None, 4), 'b': 3, 'a': 2}],
[(5, 6), OrderedDict(b=9, c=0, a=(7, 8))],
) == (
[1, 1, 2, 2, 3, 4],
[5, 6, 7, 8, 9, 0],
)
assert optree.broadcast_common(
[1, {'b': 3, 'a': 2, 'c': (None, 4)}],
[(5, 6), OrderedDict(b=9, c=0, a=(7, 8))],
none_is_leaf=True,
) == (
[1, 1, 2, 2, 3, None, 4],
[5, 6, 7, 8, 9, 0, 0],
)


def test_tree_reduce():
Expand Down
32 changes: 32 additions & 0 deletions tests/test_prefix_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,20 @@ def test_different_types():
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs)
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = {'a': 1, 'b': 2}, defaultdict(int, {'a': 1, 'b': [2, 3]})
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs)
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)


Expand Down Expand Up @@ -266,30 +274,50 @@ def test_different_metadata():
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore key ordering
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = defaultdict(list, {'a': 1, 'b': 2}), defaultdict(set, {'b': [4, 5], 'a': 3})
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore default factory
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = {'a': 1, 'b': 2}, defaultdict(list, {'b': [4, 5], 'a': 3})
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore dictionary types
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = OrderedDict({'b': 5, 'a': 4}), {'a': 1, 'b': [2, 3]}
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore dictionary types
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = deque([1, 2], maxlen=None), deque([3, [4, 5]], maxlen=3)
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore maxlen
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)

lhs, rhs = FlatCache([None, 1]), FlatCache(1)
Expand Down Expand Up @@ -432,6 +460,10 @@ def test_no_errors():
optree.tree_map_(lambda x, y: None, lhs, rhs)
lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs)
assert lhs_treespec.is_prefix(rhs_treespec)
assert (
len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves)))
== rhs_treespec.num_leaves
)
() = optree.prefix_errors(lhs, rhs)


Expand Down
6 changes: 4 additions & 2 deletions tests/test_treespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def test_treespec_compose_children(tree, inner_tree, none_is_leaf, namespace):
inner_treespec.num_nodes * treespec.num_leaves
)
assert composed_treespec.num_nodes == expected_nodes
leaves = [1] * expected_leaves
leaves = list(range(expected_leaves))
composed = optree.tree_unflatten(composed_treespec, leaves)
assert leaves == optree.tree_leaves(
composed,
Expand Down Expand Up @@ -569,7 +569,9 @@ def test_treespec_num_nodes(tree, none_is_leaf, namespace):
while stack:
spec = stack.pop()
nodes.append(spec)
stack.extend(reversed(spec.children()))
children = spec.children()
stack.extend(reversed(children))
assert spec.num_nodes == sum(child.num_nodes for child in children) + 1
assert treespec.num_nodes == len(nodes)


Expand Down

0 comments on commit 4180be3

Please sign in to comment.