diff --git a/tests/test_more.py b/tests/test_more.py index 30e5a6a..440c1be 100644 --- a/tests/test_more.py +++ b/tests/test_more.py @@ -1,3 +1,4 @@ +import ast from collections.abc import Mapping from unification import var @@ -29,6 +30,14 @@ def test_unify_object(): assert stream_eval(_unify_object(Foo(1, 2), Foo(1, x), {})) == {x: 2} +def test_unify_nonstandard_object(): + _unify.add((ast.AST, ast.AST, Mapping), _unify_object) + x = var() + assert unify(ast.Num(n=1), ast.Num(n=1), {}) == {} + assert unify(ast.Num(n=1), ast.Num(n=2), {}) is False + assert unify(ast.Num(n=1), ast.Num(n=x), {}) == {x: 1} + + def test_reify_object(): x = var() obj = stream_eval(_reify_object(Foo(1, x), {x: 4})) @@ -39,6 +48,14 @@ def test_reify_object(): assert stream_eval(_reify_object(f, {})) is f +def test_reify_nonstandard_object(): + _reify.add((ast.AST, Mapping), _reify_object) + x = var() + assert reify(ast.Num(n=1), {}).n == 1 + assert reify(ast.Num(n=x), {}).n == x + assert reify(ast.Num(n=x), {x: 2}).n == 2 + + def test_reify_slots(): class SlotsObject(object): __slots__ = ["myattr"] diff --git a/unification/more.py b/unification/more.py index 0922284..071e638 100644 --- a/unification/more.py +++ b/unification/more.py @@ -53,7 +53,7 @@ def _reify_object(o, s): def _reify_object_dict(o, s): - obj = object.__new__(type(o)) + obj = type(o).__new__(type(o)) d = yield _reify(o.__dict__, s)