Skip to content

Commit

Permalink
Do not trigger B901 with explicit Generator return type (#481)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jul 2, 2024
1 parent b15feed commit cfda1e8
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 8 deletions.
34 changes: 30 additions & 4 deletions bugbear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,10 +1187,27 @@ def _loop(parent, node):
for child in node.body:
yield from _loop(node, child)

def check_for_b901(self, node):
def check_for_b901(self, node: ast.FunctionDef) -> None:
if node.name == "__await__":
return

# If the user explicitly wrote the 3-argument version of Generator as the
# return annotation, they probably know what they were doing.
if (
node.returns is not None
and isinstance(node.returns, ast.Subscript)
and (
is_name(node.returns.value, "Generator")
or is_name(node.returns.value, "typing.Generator")
or is_name(node.returns.value, "collections.abc.Generator")
)
):
slice = node.returns.slice
if sys.version_info < (3, 9) and isinstance(slice, ast.Index):
slice = slice.value
if isinstance(slice, ast.Tuple) and len(slice.elts) == 3:
return

has_yield = False
return_node = None

Expand All @@ -1204,9 +1221,8 @@ def check_for_b901(self, node):
if isinstance(x, ast.Return) and x.value is not None:
return_node = x

if has_yield and return_node is not None:
self.errors.append(B901(return_node.lineno, return_node.col_offset))
break
if has_yield and return_node is not None:
self.errors.append(B901(return_node.lineno, return_node.col_offset))

# taken from pep8-naming
@classmethod
Expand Down Expand Up @@ -1703,6 +1719,16 @@ def compose_call_path(node):
yield node.id


def is_name(node: ast.expr, name: str) -> bool:
if "." not in name:
return isinstance(node, ast.Name) and node.id == name
else:
if not isinstance(node, ast.Attribute):
return False
rest, attr = name.rsplit(".", maxsplit=1)
return node.attr == attr and is_name(node.value, rest)


def _transform_slice_to_py39(slice: ast.expr | ast.Slice) -> ast.Slice | ast.expr:
"""Transform a py38 style slice to a py39 style slice.
Expand Down
38 changes: 35 additions & 3 deletions tests/b901.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""
Should emit:
B901 - on lines 9, 36
B901
"""

def broken():
if True:
return [1, 2, 3]
return [1, 2, 3] # B901

yield 3
yield 2
Expand All @@ -32,7 +32,7 @@ def not_broken3():


def broken2():
return [3, 2, 1]
return [3, 2, 1] # B901

yield from not_broken()

Expand Down Expand Up @@ -75,3 +75,35 @@ class NotBroken9(object):
def __await__(self):
yield from function()
return 42


def broken3():
if True:
return [1, 2, 3] # B901
else:
yield 3


def broken4() -> Iterable[str]:
yield "x"
return ["x"] # B901


def broken5() -> Generator[str]:
yield "x"
return ["x"] # B901


def not_broken10() -> Generator[str, int, float]:
yield "x"
return 1.0


def not_broken11() -> typing.Generator[str, int, float]:
yield "x"
return 1.0


def not_broken12() -> collections.abc.Generator[str, int, float]:
yield "x"
return 1.0
5 changes: 4 additions & 1 deletion tests/test_bugbear.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,10 @@ def test_b901(self):
filename = Path(__file__).absolute().parent / "b901.py"
bbc = BugBearChecker(filename=str(filename))
errors = list(bbc.run())
self.assertEqual(errors, self.errors(B901(8, 8), B901(35, 4)))
self.assertEqual(
errors,
self.errors(B901(8, 8), B901(35, 4), B901(82, 8), B901(89, 4), B901(94, 4)),
)

def test_b902(self):
filename = Path(__file__).absolute().parent / "b902.py"
Expand Down

0 comments on commit cfda1e8

Please sign in to comment.