diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d301d9c..7bc8a44a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ * Add whitelist for `pint.UnitRegistry.default_formatter` (Ben Elliston, #258). +* Mark imports in `__all__` as used (kreathon, #172). # 2.4 (2022-05-19) diff --git a/tests/test_imports.py b/tests/test_imports.py index 3efb218f..af004d33 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -209,6 +209,93 @@ def test_import_alias_use_alias(v): check(v.unused_vars, []) +def test_import_with__all__(v): + v.scan( + """\ +# define.py +class Foo: + pass + +class Bar: + pass + +# main.py +from define import Foo, Bar + +__all__ = ["Foo"] + +""" + ) + check(v.defined_imports, ["Foo", "Bar"]) + check(v.unused_imports, ["Bar"]) + + +def test_import_with__all__normal_reference(v): + v.scan( + """\ +# define.py +class Foo: + pass + +class Bar: + pass + +# main.py +from define import Foo, Bar + +__all__ = [Foo] + +""" + ) + check(v.defined_imports, ["Foo", "Bar"]) + check(v.unused_imports, ["Bar"]) + + +def test_import_with__all__string(v): + v.scan( + """\ +# define.py +class Foo: + pass + +class Bar: + pass + +# main.py +from define import Foo, Bar + +__all__ = "Foo" + +""" + ) + check(v.defined_imports, ["Foo", "Bar"]) + # __all__ is not a list or tuple, so Foo is unused. + check(v.unused_imports, ["Foo", "Bar"]) + + +def test_import_with__all__assign_other_module(v): + v.scan( + """\ +# define.py +class Foo: + pass + +class Bar: + pass + +# main.py +import define +from define import Foo, Bar + +define.__all__ = ["Foo"] + +""" + ) + check(v.defined_imports, ["define", "Foo", "Bar"]) + # Only assignments to __all__ of the current module are covered. + check(v.unused_imports, ["Foo", "Bar"]) + + def test_ignore_init_py_files(v): v.scan( """\ diff --git a/vulture/core.py b/vulture/core.py index 8111fbf2..3f716f72 100644 --- a/vulture/core.py +++ b/vulture/core.py @@ -53,6 +53,15 @@ def _is_test_file(filename): ) +def _assigns_special_variable__all__(node): + assert isinstance(node, ast.Assign) + return isinstance(node.value, (ast.List, ast.Tuple)) and any( + target.id == "__all__" + for target in node.targets + if isinstance(target, ast.Name) + ) + + def _ignore_class(filename, class_name): return _is_test_file(filename) and "Test" in class_name @@ -616,6 +625,13 @@ def visit_Name(self, node): elif isinstance(node.ctx, (ast.Param, ast.Store)): self._define_variable(node.id, node) + def visit_Assign(self, node): + if _assigns_special_variable__all__(node): + assert isinstance(node.value, (ast.List, ast.Tuple)) + for elt in node.value.elts: + if isinstance(elt, ast.Str): + self.used_names.add(elt.s) + def visit_While(self, node): self._handle_conditional_node(node, "while")