From 4180be39415ee7c5b5ed5d668f0c4aac49ecaa23 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 6 Oct 2023 13:59:37 +0800 Subject: [PATCH] test: update tests for braodcast --- tests/test_ops.py | 204 ++++++++++++++++++++++++++++++++---- tests/test_prefix_errors.py | 32 ++++++ tests/test_treespec.py | 6 +- 3 files changed, 218 insertions(+), 24 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index 2979309b..7b7884e7 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -806,6 +806,16 @@ def add_leaves(p, x): assert leaves == [1, 2, 3, None, 4] +def test_tree_replace_nones(): + sentinel = object() + assert optree.tree_replace_nones(sentinel, {'a': 1, 'b': None, 'c': (2, None)}) == { + 'a': 1, + 'b': sentinel, + 'c': (2, sentinel), + } + assert optree.tree_replace_nones(sentinel, None) == sentinel + + @parametrize( tree=TREES, none_is_leaf=[False, True], @@ -931,36 +941,39 @@ class MyExtraDict(MyAnotherDict): def test_tree_broadcast_prefix(): - assert optree.tree_broadcast_prefix(1, [1, 2, 3]) == [1, 1, 1] - assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, 3]) == [1, 2, 3] + assert optree.tree_broadcast_prefix(1, [2, 3, 4]) == [1, 1, 1] + assert optree.tree_broadcast_prefix([1, 2, 3], [4, 5, 6]) == [1, 2, 3] with pytest.raises( ValueError, - match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].'), + match=re.escape('list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].'), ): - optree.tree_broadcast_prefix([1, 2, 3], [1, 2, 3, 4]) - assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) == [1, 2, (3, 3)] - assert optree.tree_broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}]) == [ + optree.tree_broadcast_prefix([1, 2, 3], [4, 5, 6, 7]) + assert optree.tree_broadcast_prefix([1, 2, 3], [4, 5, (6, 7)]) == [1, 2, (3, 3)] + assert optree.tree_broadcast_prefix( + [1, 2, 3], + [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], + ) == [ 1, 2, {'a': 3, 'b': 3, 'c': (None, 3)}, ] assert optree.tree_broadcast_prefix( [1, 2, 3], - [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], + [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True, ) == [1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}] def test_broadcast_prefix(): - assert optree.broadcast_prefix(1, [1, 2, 3]) == [1, 1, 1] - assert optree.broadcast_prefix([1, 2, 3], [1, 2, 3]) == [1, 2, 3] + assert optree.broadcast_prefix(1, [2, 3, 4]) == [1, 1, 1] + assert optree.broadcast_prefix([1, 2, 3], [4, 5, 6]) == [1, 2, 3] with pytest.raises( ValueError, - match=re.escape('list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].'), + match=re.escape('list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].'), ): - optree.broadcast_prefix([1, 2, 3], [1, 2, 3, 4]) - assert optree.broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) == [1, 2, 3, 3] - assert optree.broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}]) == [ + optree.broadcast_prefix([1, 2, 3], [4, 5, 6, 7]) + assert optree.broadcast_prefix([1, 2, 3], [4, 5, (6, 7)]) == [1, 2, 3, 3] + assert optree.broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}]) == [ 1, 2, 3, @@ -969,19 +982,166 @@ def test_broadcast_prefix(): ] assert optree.broadcast_prefix( [1, 2, 3], - [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], + [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True, ) == [1, 2, 3, 3, 3, 3] -def test_tree_replace_nones(): - sentinel = object() - assert optree.tree_replace_nones(sentinel, {'a': 1, 'b': None, 'c': (2, None)}) == { - 'a': 1, - 'b': sentinel, - 'c': (2, sentinel), - } - assert optree.tree_replace_nones(sentinel, None) == sentinel +def test_tree_broadcast_common(): + assert optree.tree_broadcast_common(1, [2, 3, 4]) == ([1, 1, 1], [2, 3, 4]) + assert optree.tree_broadcast_common([1, 2, 3], [4, 5, 6]) == ([1, 2, 3], [4, 5, 6]) + with pytest.raises( + ValueError, + match=re.escape('list arity mismatch; expected: 3, got: 4.'), + ): + optree.tree_broadcast_common([1, 2, 3], [1, 2, 3, 4]) + assert optree.tree_broadcast_common([1, 2, 3], [4, 5, (6, 7)]) == ( + [1, 2, (3, 3)], + [4, 5, (6, 7)], + ) + assert optree.tree_broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)]) == ( + [1, (2, 3), (4, 4)], + [5, (6, 6), (7, 8)], + ) + assert optree.tree_broadcast_common( + [1, (2, 3), 4], + [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], + ) == ( + [1, (2, 3), {'a': 4, 'b': 4, 'c': (None, 4)}], + [5, (6, 6), {'a': 7, 'b': 8, 'c': (None, 9)}], + ) + assert optree.tree_broadcast_common( + [1, OrderedDict(foo=2, bar=3), 4], + [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], + ) == ( + [1, OrderedDict(foo=2, bar=3), {'a': 4, 'b': 4, 'c': (None, 4)}], + [5, OrderedDict(foo=6, bar=6), {'a': 7, 'b': 8, 'c': (None, 9)}], + ) + assert optree.tree_broadcast_common( + [1, OrderedDict(foo=2, bar=3), 4], + [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], + none_is_leaf=True, + ) == ( + [1, OrderedDict(foo=2, bar=3), {'a': 4, 'b': 4, 'c': (4, 4)}], + [5, OrderedDict(foo=6, bar=6), {'a': 7, 'b': 8, 'c': (None, 9)}], + ) + assert optree.tree_broadcast_common( + [1, OrderedDict(foo=2, bar=3), 4], + [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], + none_is_leaf=True, + ) == ( + [1, OrderedDict(foo=2, bar=3), {'a': 4, 'b': 4, 'c': (4, 4)}], + [5, OrderedDict(foo=6, bar=6), {'a': 7, 'b': 8, 'c': (None, 9)}], + ) + assert optree.tree_broadcast_common( + [1, OrderedDict(b=4, c=5, a=(2, 3))], + [(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}], + ) == ( + [(1, 1), OrderedDict(b=4, c=(None, 5), a=(2, 3))], + [(6, 7), {'c': (None, 0), 'a': (8, 8), 'b': 9}], + ) + assert optree.tree_broadcast_common( + [1, OrderedDict(b=4, c=5, a=(2, 3))], + [(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}], + none_is_leaf=True, + ) == ( + [(1, 1), OrderedDict(b=4, c=(5, 5), a=(2, 3))], + [(6, 7), {'c': (None, 0), 'a': (8, 8), 'b': 9}], + ) + assert optree.tree_broadcast_common( + [1, {'c': (None, 4), 'b': 3, 'a': 2}], + [(5, 6), OrderedDict(b=9, c=0, a=(7, 8))], + ) == ( + [(1, 1), {'c': (None, 4), 'b': 3, 'a': (2, 2)}], + [(5, 6), OrderedDict(b=9, c=(None, 0), a=(7, 8))], + ) + assert optree.tree_broadcast_common( + [1, {'b': 3, 'a': 2, 'c': (None, 4)}], + [(5, 6), OrderedDict(b=9, c=0, a=(7, 8))], + none_is_leaf=True, + ) == ( + [(1, 1), {'c': (None, 4), 'b': 3, 'a': (2, 2)}], + [(5, 6), OrderedDict(b=9, c=(0, 0), a=(7, 8))], + ) + + +def test_broadcast_common(): + assert optree.broadcast_common(1, [2, 3, 4]) == ([1, 1, 1], [2, 3, 4]) + assert optree.broadcast_common([1, 2, 3], [4, 5, 6]) == ([1, 2, 3], [4, 5, 6]) + with pytest.raises( + ValueError, + match=re.escape('list arity mismatch; expected: 3, got: 4.'), + ): + optree.broadcast_common([1, 2, 3], [1, 2, 3, 4]) + assert optree.broadcast_common([1, 2, 3], [4, 5, (6, 7)]) == ( + [1, 2, 3, 3], + [4, 5, 6, 7], + ) + assert optree.broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)]) == ( + [1, 2, 3, 4, 4], + [5, 6, 6, 7, 8], + ) + assert optree.broadcast_common( + [1, (2, 3), 4], + [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], + ) == ( + [1, 2, 3, 4, 4, 4], + [5, 6, 6, 7, 8, 9], + ) + assert optree.broadcast_common( + [1, (2, 3), 4], + [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], + none_is_leaf=True, + ) == ( + [1, 2, 3, 4, 4, 4, 4], + [5, 6, 6, 7, 8, None, 9], + ) + assert optree.broadcast_common( + [1, OrderedDict(foo=2, bar=3), 4], + [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], + ) == ( + [1, 2, 3, 4, 4, 4], + [5, 6, 6, 7, 8, 9], + ) + assert optree.broadcast_common( + [1, OrderedDict(foo=2, bar=3), 4], + [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], + none_is_leaf=True, + ) == ( + [1, 2, 3, 4, 4, 4, 4], + [5, 6, 6, 7, 8, None, 9], + ) + + assert optree.broadcast_common( + [1, OrderedDict(b=4, c=5, a=(2, 3))], + [(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}], + ) == ( + [1, 1, 4, 5, 2, 3], + [6, 7, 9, 0, 8, 8], + ) + assert optree.broadcast_common( + [1, OrderedDict(b=4, c=5, a=(2, 3))], + [(6, 7), {'c': (None, 0), 'a': 8, 'b': 9}], + none_is_leaf=True, + ) == ( + [1, 1, 4, 5, 5, 2, 3], + [6, 7, 9, None, 0, 8, 8], + ) + assert optree.broadcast_common( + [1, {'c': (None, 4), 'b': 3, 'a': 2}], + [(5, 6), OrderedDict(b=9, c=0, a=(7, 8))], + ) == ( + [1, 1, 2, 2, 3, 4], + [5, 6, 7, 8, 9, 0], + ) + assert optree.broadcast_common( + [1, {'b': 3, 'a': 2, 'c': (None, 4)}], + [(5, 6), OrderedDict(b=9, c=0, a=(7, 8))], + none_is_leaf=True, + ) == ( + [1, 1, 2, 2, 3, None, 4], + [5, 6, 7, 8, 9, 0, 0], + ) def test_tree_reduce(): diff --git a/tests/test_prefix_errors.py b/tests/test_prefix_errors.py index cb935f95..e43479ec 100644 --- a/tests/test_prefix_errors.py +++ b/tests/test_prefix_errors.py @@ -81,12 +81,20 @@ def test_different_types(): lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs) optree.tree_map_(lambda x, y: None, lhs, rhs) assert lhs_treespec.is_prefix(rhs_treespec) + assert ( + len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves))) + == rhs_treespec.num_leaves + ) () = optree.prefix_errors(lhs, rhs) lhs, rhs = {'a': 1, 'b': 2}, defaultdict(int, {'a': 1, 'b': [2, 3]}) lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs) optree.tree_map_(lambda x, y: None, lhs, rhs) assert lhs_treespec.is_prefix(rhs_treespec) + assert ( + len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves))) + == rhs_treespec.num_leaves + ) () = optree.prefix_errors(lhs, rhs) @@ -266,30 +274,50 @@ def test_different_metadata(): lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs) optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore key ordering assert lhs_treespec.is_prefix(rhs_treespec) + assert ( + len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves))) + == rhs_treespec.num_leaves + ) () = optree.prefix_errors(lhs, rhs) lhs, rhs = defaultdict(list, {'a': 1, 'b': 2}), defaultdict(set, {'b': [4, 5], 'a': 3}) lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs) optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore default factory assert lhs_treespec.is_prefix(rhs_treespec) + assert ( + len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves))) + == rhs_treespec.num_leaves + ) () = optree.prefix_errors(lhs, rhs) lhs, rhs = {'a': 1, 'b': 2}, defaultdict(list, {'b': [4, 5], 'a': 3}) lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs) optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore dictionary types assert lhs_treespec.is_prefix(rhs_treespec) + assert ( + len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves))) + == rhs_treespec.num_leaves + ) () = optree.prefix_errors(lhs, rhs) lhs, rhs = OrderedDict({'b': 5, 'a': 4}), {'a': 1, 'b': [2, 3]} lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs) optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore dictionary types assert lhs_treespec.is_prefix(rhs_treespec) + assert ( + len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves))) + == rhs_treespec.num_leaves + ) () = optree.prefix_errors(lhs, rhs) lhs, rhs = deque([1, 2], maxlen=None), deque([3, [4, 5]], maxlen=3) lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs) optree.tree_map_(lambda x, y: None, lhs, rhs) # ignore maxlen assert lhs_treespec.is_prefix(rhs_treespec) + assert ( + len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves))) + == rhs_treespec.num_leaves + ) () = optree.prefix_errors(lhs, rhs) lhs, rhs = FlatCache([None, 1]), FlatCache(1) @@ -432,6 +460,10 @@ def test_no_errors(): optree.tree_map_(lambda x, y: None, lhs, rhs) lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs) assert lhs_treespec.is_prefix(rhs_treespec) + assert ( + len(lhs_treespec.broadcast_to(rhs_treespec, range(lhs_treespec.num_leaves))) + == rhs_treespec.num_leaves + ) () = optree.prefix_errors(lhs, rhs) diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 490b20e5..93ccebd1 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -333,7 +333,7 @@ def test_treespec_compose_children(tree, inner_tree, none_is_leaf, namespace): inner_treespec.num_nodes * treespec.num_leaves ) assert composed_treespec.num_nodes == expected_nodes - leaves = [1] * expected_leaves + leaves = list(range(expected_leaves)) composed = optree.tree_unflatten(composed_treespec, leaves) assert leaves == optree.tree_leaves( composed, @@ -569,7 +569,9 @@ def test_treespec_num_nodes(tree, none_is_leaf, namespace): while stack: spec = stack.pop() nodes.append(spec) - stack.extend(reversed(spec.children())) + children = spec.children() + stack.extend(reversed(children)) + assert spec.num_nodes == sum(child.num_nodes for child in children) + 1 assert treespec.num_nodes == len(nodes)