diff --git a/Makefile b/Makefile index fe66dbe..bcdbba9 100644 --- a/Makefile +++ b/Makefile @@ -2,17 +2,18 @@ all: lint test clean .PHONY: dev dev: - pip install -e .[test] - pip install pre-commit + python3.10 -m venv venv + source venv/bin/activate; pip install pip -U; pip install -e .[test] + source venv/bin/activate; pip install pre-commit .PHONY: lint lint: git add . - pre-commit run --all-files + source venv/bin/activate; pre-commit run .PHONY: test test: - pytest tests -x -v --disable-warnings + source venv/bin/activate; pytest tests -x -v --disable-warnings .PHONY: tox tox: @@ -67,8 +68,8 @@ publish: .PHONY: docs docs: - pip install -e .[docs] - mkdocs serve + source venv/bin/activate; pip install -e .[docs] + source venv/bin/activate; mkdocs serve .PHONY: sync-main sync-main: diff --git a/src/unimport/analyzers/import_statement.py b/src/unimport/analyzers/import_statement.py index 3b3c00d..ed233e2 100644 --- a/src/unimport/analyzers/import_statement.py +++ b/src/unimport/analyzers/import_statement.py @@ -1,9 +1,13 @@ from __future__ import annotations import ast +import typing +import unimport.constants as C +from unimport import utils from unimport.analyzers.decarators import generic_visit, skip_import -from unimport.analyzers.importable import ImportableAnalyzer +from unimport.analyzers.importable import ImportableNameAnalyzer, SuggestionNameAnalyzer +from unimport.analyzers.utils import set_tree_parents from unimport.statement import Import, ImportFrom, Name, Scope __all__ = ("ImportAnalyzer",) @@ -13,8 +17,10 @@ class ImportAnalyzer(ast.NodeVisitor): __slots__ = ( "source", "include_star_import", - "any_import_error", "defined_names", + "any_import_error", + "if_names", + "orelse_names", ) IGNORE_MODULES_IMPORTS = ("__future__",) @@ -35,16 +41,14 @@ def __init__( def traverse(self, tree) -> None: self.visit(tree) - def scope_analysis(self, node): + def visit_def(self, node): Scope.add_current_scope(node) self.generic_visit(node) Scope.remove_current_scope() - visit_ClassDef = scope_analysis - visit_FunctionDef = scope_analysis - visit_AsyncFunctionDef = scope_analysis + visit_ClassDef = visit_FunctionDef = visit_AsyncFunctionDef = visit_def @generic_visit @skip_import @@ -105,7 +109,35 @@ def visit_Try(self, node: ast.Try) -> None: self.any_import_error = False + @classmethod + def iget_importable_name(cls, package: str) -> typing.Iterator[str]: + if utils.is_std(package): + yield from utils.get_module_dir(package) + + elif source := utils.get_source(package): + try: + tree = ast.parse(source) + except SyntaxError: + pass + else: + importable_name_analyzer = ImportableNameAnalyzer() + importable_name_analyzer.traverse(tree) + if importable_name_analyzer.importable_nodes: + for node in importable_name_analyzer.importable_nodes: + yield node.value + else: + suggestion_name_analyzer = SuggestionNameAnalyzer() + set_tree_parents(tree) + suggestion_name_analyzer.traverse(tree) + for node in suggestion_name_analyzer.suggestions_nodes: # type: ignore[assignment] + if isinstance(node, ast.Name): + yield node.id + elif isinstance(node, ast.alias): + yield node.asname or node.name + elif isinstance(node, C.DEF_TUPLE): + yield node.name + def get_suggestions(self, package: str) -> list[str]: names = set(map(lambda name: name.name.split(".")[0], Name.names)) - from_names = ImportableAnalyzer.get_names(package) - return sorted(from_names & (names - self.defined_names)) + from_names = self.iget_importable_name(package) + return sorted(set(from_names) & (names - self.defined_names)) diff --git a/src/unimport/analyzers/importable.py b/src/unimport/analyzers/importable.py index 9efdeba..cdacbd5 100644 --- a/src/unimport/analyzers/importable.py +++ b/src/unimport/analyzers/importable.py @@ -2,59 +2,28 @@ import ast -from unimport import constants as C -from unimport import typing as T -from unimport import utils +import unimport.constants as C +import unimport.typing as T from unimport.analyzers.decarators import generic_visit -from unimport.analyzers.utils import first_parent_match, set_tree_parents -from unimport.statement import Name, Scope +from unimport.analyzers.utils import first_parent_match + +__all__ = ( + "ImportableNameAnalyzer", + "SuggestionNameAnalyzer", +) -__all__ = ("ImportableAnalyzer",) +from unimport.statement import Name, Scope -class ImportableAnalyzer(ast.NodeVisitor): - __slots__ = ( - "importable_nodes", - "suggestions_nodes", - ) +class ImportableNameAnalyzer(ast.NodeVisitor): + __slots__ = ("importable_nodes",) def __init__(self) -> None: self.importable_nodes: list[ast.Constant] = [] # nodes on the __all__ list - self.suggestions_nodes: list[T.ASTImportableT] = [] # nodes on the CFN def traverse(self, tree): self.visit(tree) - for node in self.importable_nodes: - Name.register(lineno=node.lineno, name=node.value, node=node, is_all=True) - - self.clear() - - def visit_CFN(self, node: T.CFNT) -> None: - Scope.add_current_scope(node) - - if not first_parent_match(node, C.DEF_TUPLE): - self.suggestions_nodes.append(node) - - self.generic_visit(node) - - Scope.remove_current_scope() - - visit_ClassDef = visit_CFN - visit_FunctionDef = visit_CFN - visit_AsyncFunctionDef = visit_CFN - - @generic_visit - def visit_Import(self, node: ast.Import) -> None: - for alias in node.names: - self.suggestions_nodes.append(alias) - - @generic_visit - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: - if not node.names[0].name == "*": - for alias in node.names: - self.suggestions_nodes.append(alias) - @generic_visit def visit_Assign(self, node: ast.Assign) -> None: if getattr(node.targets[0], "id", None) == "__all__" and isinstance(node.value, (ast.List, ast.Tuple, ast.Set)): @@ -62,10 +31,6 @@ def visit_Assign(self, node: ast.Assign) -> None: if isinstance(item, ast.Constant) and isinstance(item.value, str): self.importable_nodes.append(item) - for target in node.targets: # we only get assigned names - if isinstance(target, (ast.Name, ast.Attribute)): - self.suggestions_nodes.append(target) - @generic_visit def visit_Expr(self, node: ast.Expr) -> None: if ( @@ -86,41 +51,53 @@ def visit_Expr(self, node: ast.Expr) -> None: if isinstance(item, ast.Constant) and isinstance(item.value, str): self.importable_nodes.append(item) - @classmethod - def get_names(cls, package: str) -> frozenset[str]: - if utils.is_std(package): - return utils.get_dir(package) - - source = utils.get_source(package) - if source: - try: - tree = ast.parse(source) - except SyntaxError: - return frozenset() - else: - visitor = cls() - set_tree_parents(tree) - visitor.visit(tree) - return visitor.get_all() or visitor.get_suggestion() - return frozenset() - - def get_all(self) -> frozenset[str]: - names = set() + +class SuggestionNameAnalyzer(ast.NodeVisitor): + __slots__ = ("suggestions_nodes",) + + def __init__(self) -> None: + self.suggestions_nodes: list[T.ASTImportableT] = [] # nodes on the CFN + + def traverse(self, tree): + self.visit(tree) + + @generic_visit + def visit_def(self, node: T.CFNT) -> None: + if not first_parent_match(node, C.DEF_TUPLE): + self.suggestions_nodes.append(node) + + visit_ClassDef = visit_FunctionDef = visit_AsyncFunctionDef = visit_def + + @generic_visit + def visit_Import(self, node: ast.Import) -> None: + for alias in node.names: + self.suggestions_nodes.append(alias) + + @generic_visit + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if not node.names[0].name == "*": + for alias in node.names: + self.suggestions_nodes.append(alias) + + @generic_visit + def visit_Assign(self, node: ast.Assign) -> None: + for target in node.targets: # we only get assigned names + if isinstance(target, (ast.Name, ast.Attribute)): + self.suggestions_nodes.append(target) + + +class ImportableNameWithScopeAnalyzer(ImportableNameAnalyzer): + def traverse(self, tree): + super().traverse(tree) + for node in self.importable_nodes: - names.add(node.value) - return frozenset(names) - - def get_suggestion(self) -> frozenset[str]: - names = set() - for node in self.suggestions_nodes: # type: ignore - if isinstance(node, ast.Name): - names.add(node.id) - elif isinstance(node, ast.alias): - names.add(node.asname or node.name) - elif isinstance(node, C.DEF_TUPLE): - names.add(node.name) - return frozenset(names) - - def clear(self): - self.importable_nodes.clear() - self.suggestions_nodes.clear() + Name.register(lineno=node.lineno, name=node.value, node=node, is_all=True) + + def visit_def(self, node: T.CFNT) -> None: + Scope.add_current_scope(node) + + self.generic_visit(node) + + Scope.remove_current_scope() + + visit_ClassDef = visit_FunctionDef = visit_AsyncFunctionDef = visit_def diff --git a/src/unimport/analyzers/main.py b/src/unimport/analyzers/main.py index 4f1621c..c9f2854 100644 --- a/src/unimport/analyzers/main.py +++ b/src/unimport/analyzers/main.py @@ -3,7 +3,7 @@ from pathlib import Path from unimport.analyzers.import_statement import ImportAnalyzer -from unimport.analyzers.importable import ImportableAnalyzer +from unimport.analyzers.importable import ImportableNameWithScopeAnalyzer from unimport.analyzers.name import NameAnalyzer from unimport.analyzers.utils import get_defined_names, set_tree_parents from unimport.statement import Import, ImportFrom, Name, Scope @@ -38,7 +38,7 @@ def traverse(self) -> None: NameAnalyzer().traverse(tree) # name analyzers - ImportableAnalyzer().traverse(tree) # importable analyzers for collect in __all__ + ImportableNameWithScopeAnalyzer().traverse(tree) # importable analyzers for collect in __all__ ImportAnalyzer( # import analyzers source=self.source, include_star_import=self.include_star_import, defined_names=get_defined_names(tree) diff --git a/src/unimport/constants.py b/src/unimport/constants.py index ffd1740..b8037f9 100644 --- a/src/unimport/constants.py +++ b/src/unimport/constants.py @@ -80,5 +80,5 @@ } ) -BUILTIN_MODULE_NAMES = frozenset(sys.builtin_module_names) +BUILTIN_MODULE_NAMES = sys.builtin_module_names STDLIB_PATH = sysconfig.get_paths()["stdlib"] diff --git a/src/unimport/statement.py b/src/unimport/statement.py index 765a861..29e655e 100644 --- a/src/unimport/statement.py +++ b/src/unimport/statement.py @@ -2,7 +2,6 @@ import ast import dataclasses -import operator import typing __all__ = ("Import", "ImportFrom", "Name", "Scope") @@ -20,7 +19,7 @@ class Import: node: ast.Import | ast.ImportFrom = dataclasses.field(init=False, repr=False, compare=False) def __len__(self) -> int: - return operator.length_hint(self.name.split(".")) + return len(self.name.split(".")) def is_match_sub_packages(self, name_name: str) -> bool: return self.name.split(".")[0] == name_name diff --git a/src/unimport/typing.py b/src/unimport/typing.py index 00d71ac..cc7380f 100644 --- a/src/unimport/typing.py +++ b/src/unimport/typing.py @@ -11,7 +11,7 @@ "CSTImportT", ) -ASTImportableT = typing.Union[ast.AsyncFunctionDef, ast.Attribute, ast.ClassDef, ast.FunctionDef, ast.Name, ast.alias] +ASTImportableT = typing.Union[ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef, ast.Attribute, ast.Name, ast.alias] ASTFunctionT = typing.TypeVar("ASTFunctionT", ast.FunctionDef, ast.AsyncFunctionDef) CFNT = typing.TypeVar("CFNT", ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef, ast.Name) diff --git a/src/unimport/utils.py b/src/unimport/utils.py index f40866a..0a209e1 100644 --- a/src/unimport/utils.py +++ b/src/unimport/utils.py @@ -1,12 +1,12 @@ from __future__ import annotations import difflib -import functools import importlib.machinery import importlib.util import re import tokenize import typing +from functools import lru_cache from pathlib import Path from pathspec.patterns import GitWildMatchPattern @@ -14,7 +14,7 @@ import unimport.constants as C __all__ = ( - "get_dir", + "get_module_dir", "get_source", "get_spec", "is_std", @@ -27,16 +27,16 @@ ) -@functools.lru_cache(maxsize=128) -def get_dir(package: str) -> frozenset[str]: +@lru_cache(maxsize=128) +def get_module_dir(package: str) -> list[str] | None: try: module = importlib.import_module(package) + return dir(module) except (ImportError, AttributeError, TypeError, ValueError): - return frozenset() - return frozenset(dir(module)) + return None -@functools.lru_cache(maxsize=128) +@lru_cache(maxsize=128) def get_source(package: str) -> str | None: spec = get_spec(package) if spec is not None: @@ -48,7 +48,7 @@ def get_source(package: str) -> str | None: return None -@functools.lru_cache(maxsize=128) +@lru_cache(maxsize=128) def get_spec(package: str) -> importlib.machinery.ModuleSpec | None: try: return importlib.util.find_spec(package) @@ -56,7 +56,7 @@ def get_spec(package: str) -> importlib.machinery.ModuleSpec | None: return None -@functools.lru_cache(maxsize=128) +@lru_cache(maxsize=128) def is_std(package: str) -> bool: """Returns True if package module came with from Python.""" if package in C.BUILTIN_MODULE_NAMES: @@ -74,7 +74,7 @@ def is_std(package: str) -> bool: return False -@functools.lru_cache(maxsize=3) +@lru_cache(maxsize=12) def action_to_bool(action: str) -> bool: """Convert a string representation of truth to true (True) or false (False). @@ -160,7 +160,7 @@ def diff(*, source: str, refactor_result: str, fromfile: Path = None) -> tuple[s ) -@functools.lru_cache(maxsize=128) +@lru_cache(maxsize=6) def return_exit_code(*, is_unused_imports: bool, is_syntax_error: bool, refactor_applied: bool) -> int: """If this function changes, be sure to update this page https://unimport.hakancelik.dev/tutorial/other-useful-features/#exit-code- diff --git a/tests/cases/analyzer/all/module_from_all_set_the_name.py b/tests/cases/analyzer/all/module_from_all_set_the_name.py new file mode 100644 index 0000000..0d221fd --- /dev/null +++ b/tests/cases/analyzer/all/module_from_all_set_the_name.py @@ -0,0 +1,19 @@ +from typing import List, Union + +from unimport.statement import Import, ImportFrom, Name + +__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"] + + +NAMES: List[Name] = [ + Name(lineno=3, name="__all__", is_all=False), + Name(lineno=8, name="Any", is_all=False), + Name(lineno=4, name="Test", is_all=True), + Name(lineno=5, name="Test2", is_all=True), +] +IMPORTS: List[Union[Import, ImportFrom]] = [ + ImportFrom(lineno=1, column=1, name="typing", package="typing", star=True, suggestions=["Any"]), +] +UNUSED_IMPORTS: List[Union[Import, ImportFrom]] = [ + ImportFrom(lineno=1, column=1, name="typing", package="typing", star=True, suggestions=["Any"]), +] diff --git a/tests/cases/refactor/all/module_from_all_set_the_name.py b/tests/cases/refactor/all/module_from_all_set_the_name.py new file mode 100644 index 0000000..faa5c6d --- /dev/null +++ b/tests/cases/refactor/all/module_from_all_set_the_name.py @@ -0,0 +1,8 @@ +from typing import Any + +__all__ = ( + "Test", + "Test2" +) + +Any diff --git a/tests/cases/source/all/module_from_all_set_the_name.py b/tests/cases/source/all/module_from_all_set_the_name.py new file mode 100644 index 0000000..05835ad --- /dev/null +++ b/tests/cases/source/all/module_from_all_set_the_name.py @@ -0,0 +1,8 @@ +from typing import * + +__all__ = ( + "Test", + "Test2" +) + +Any