Skip to content

Commit

Permalink
Deduplicate examples by repr
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Jul 9, 2023
1 parent 69ac026 commit c14725b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion hypothesis-python/src/hypothesis/extra/_patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def get_patch_for(func, failing_examples, *, strip_via=()):
# The printed examples might include object reprs which are invalid syntax,
# so we parse here and skip over those. If _none_ are valid, there's no patch.
call_nodes = []
for ex, via in failing_examples:
for ex, via in set(failing_examples):
with suppress(Exception):
node = cst.parse_expression(ex)
assert isinstance(node, cst.Call), node
Expand Down
13 changes: 10 additions & 3 deletions hypothesis-python/tests/patching/test_patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def test_make_full_patch(tst, example, expected, body, remove):

@pytest.mark.parametrize("n", [0, 1, 2])
def test_invalid_syntax_cases_dropped(n):
tst, example, expected = SIMPLE
example_ls = [example] * n
tst, (ex, via), expected = SIMPLE
example_ls = [(ex.replace("x=1", f"x={x}"), via) for x in range(n)]
example_ls.insert(-1, ("fn(\n x=<__main__.Cls object at 0x>,\n)", FAIL_MSG))

got = get_patch_for(tst, example_ls)
Expand All @@ -158,7 +158,7 @@ def test_invalid_syntax_cases_dropped(n):
where, _, after = got

assert Path(where) == WHERE
assert after.count(expected.lstrip("+")) == n
assert after.count("@example(x=") == n


def test_no_example_for_data_strategy():
Expand All @@ -169,6 +169,13 @@ def test_no_example_for_data_strategy():
assert get_patch_for(fn, [("fn(Foo(data=data(...)))", "msg")]) is not None


def test_deduplicates_examples():
tst, example, expected = SIMPLE
where, _, after = get_patch_for(tst, [example, example])
assert Path(where) == WHERE
assert after.count(expected.lstrip("+")) == 1


def test_irretrievable_callable():
# Check that we return None instead of raising an exception
old_module = fn.__module__
Expand Down

0 comments on commit c14725b

Please sign in to comment.