Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
hakancelikdev committed Jan 7, 2024
1 parent fae6ae7 commit 563fb26
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 114 deletions.
13 changes: 7 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ all: lint test clean

.PHONY: dev
dev:
pip install -e .[test]
pip install pre-commit
python3.10 -m venv venv
source venv/bin/activate; pip install pip -U; pip install -e .[test]
source venv/bin/activate; pip install pre-commit

.PHONY: lint
lint:
git add .
pre-commit run --all-files
source venv/bin/activate; pre-commit run

.PHONY: test
test:
pytest tests -x -v --disable-warnings
source venv/bin/activate; pytest tests -x -v --disable-warnings

.PHONY: tox
tox:
Expand Down Expand Up @@ -67,8 +68,8 @@ publish:

.PHONY: docs
docs:
pip install -e .[docs]
mkdocs serve
source venv/bin/activate; pip install -e .[docs]
source venv/bin/activate; mkdocs serve

.PHONY: sync-main
sync-main:
Expand Down
48 changes: 40 additions & 8 deletions src/unimport/analyzers/import_statement.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import ast
import typing

import unimport.constants as C
from unimport import utils
from unimport.analyzers.decarators import generic_visit, skip_import
from unimport.analyzers.importable import ImportableAnalyzer
from unimport.analyzers.importable import ImportableNameAnalyzer, SuggestionNameAnalyzer
from unimport.analyzers.utils import set_tree_parents
from unimport.statement import Import, ImportFrom, Name, Scope

__all__ = ("ImportAnalyzer",)
Expand All @@ -13,8 +17,10 @@ class ImportAnalyzer(ast.NodeVisitor):
__slots__ = (
"source",
"include_star_import",
"any_import_error",
"defined_names",
"any_import_error",
"if_names",
"orelse_names",
)

IGNORE_MODULES_IMPORTS = ("__future__",)
Expand All @@ -35,16 +41,14 @@ def __init__(
def traverse(self, tree) -> None:
self.visit(tree)

def scope_analysis(self, node):
def visit_def(self, node):
Scope.add_current_scope(node)

self.generic_visit(node)

Scope.remove_current_scope()

visit_ClassDef = scope_analysis
visit_FunctionDef = scope_analysis
visit_AsyncFunctionDef = scope_analysis
visit_ClassDef = visit_FunctionDef = visit_AsyncFunctionDef = visit_def

@generic_visit
@skip_import
Expand Down Expand Up @@ -105,7 +109,35 @@ def visit_Try(self, node: ast.Try) -> None:

self.any_import_error = False

@classmethod
def iget_importable_name(cls, package: str) -> typing.Iterator[str]:
if utils.is_std(package):
yield from utils.get_module_dir(package)

elif source := utils.get_source(package):
try:
tree = ast.parse(source)
except SyntaxError:
pass
else:
importable_name_analyzer = ImportableNameAnalyzer()
importable_name_analyzer.traverse(tree)
if importable_name_analyzer.importable_nodes:
for node in importable_name_analyzer.importable_nodes:
yield node.value
else:
suggestion_name_analyzer = SuggestionNameAnalyzer()
set_tree_parents(tree)
suggestion_name_analyzer.traverse(tree)
for node in suggestion_name_analyzer.suggestions_nodes: # type: ignore[assignment]
if isinstance(node, ast.Name):
yield node.id
elif isinstance(node, ast.alias):
yield node.asname or node.name
elif isinstance(node, C.DEF_TUPLE):
yield node.name

def get_suggestions(self, package: str) -> list[str]:
names = set(map(lambda name: name.name.split(".")[0], Name.names))
from_names = ImportableAnalyzer.get_names(package)
return sorted(from_names & (names - self.defined_names))
from_names = self.iget_importable_name(package)
return sorted(set(from_names) & (names - self.defined_names))
143 changes: 60 additions & 83 deletions src/unimport/analyzers/importable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,70 +2,35 @@

import ast

from unimport import constants as C
from unimport import typing as T
from unimport import utils
import unimport.constants as C
import unimport.typing as T
from unimport.analyzers.decarators import generic_visit
from unimport.analyzers.utils import first_parent_match, set_tree_parents
from unimport.statement import Name, Scope
from unimport.analyzers.utils import first_parent_match

__all__ = (
"ImportableNameAnalyzer",
"SuggestionNameAnalyzer",
)

__all__ = ("ImportableAnalyzer",)
from unimport.statement import Name, Scope


class ImportableAnalyzer(ast.NodeVisitor):
__slots__ = (
"importable_nodes",
"suggestions_nodes",
)
class ImportableNameAnalyzer(ast.NodeVisitor):
__slots__ = ("importable_nodes",)

def __init__(self) -> None:
self.importable_nodes: list[ast.Constant] = [] # nodes on the __all__ list
self.suggestions_nodes: list[T.ASTImportableT] = [] # nodes on the CFN

def traverse(self, tree):
self.visit(tree)

for node in self.importable_nodes:
Name.register(lineno=node.lineno, name=node.value, node=node, is_all=True)

self.clear()

def visit_CFN(self, node: T.CFNT) -> None:
Scope.add_current_scope(node)

if not first_parent_match(node, C.DEF_TUPLE):
self.suggestions_nodes.append(node)

self.generic_visit(node)

Scope.remove_current_scope()

visit_ClassDef = visit_CFN
visit_FunctionDef = visit_CFN
visit_AsyncFunctionDef = visit_CFN

@generic_visit
def visit_Import(self, node: ast.Import) -> None:
for alias in node.names:
self.suggestions_nodes.append(alias)

@generic_visit
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if not node.names[0].name == "*":
for alias in node.names:
self.suggestions_nodes.append(alias)

@generic_visit
def visit_Assign(self, node: ast.Assign) -> None:
if getattr(node.targets[0], "id", None) == "__all__" and isinstance(node.value, (ast.List, ast.Tuple, ast.Set)):
for item in node.value.elts:
if isinstance(item, ast.Constant) and isinstance(item.value, str):
self.importable_nodes.append(item)

for target in node.targets: # we only get assigned names
if isinstance(target, (ast.Name, ast.Attribute)):
self.suggestions_nodes.append(target)

@generic_visit
def visit_Expr(self, node: ast.Expr) -> None:
if (
Expand All @@ -86,41 +51,53 @@ def visit_Expr(self, node: ast.Expr) -> None:
if isinstance(item, ast.Constant) and isinstance(item.value, str):
self.importable_nodes.append(item)

@classmethod
def get_names(cls, package: str) -> frozenset[str]:
if utils.is_std(package):
return utils.get_dir(package)

source = utils.get_source(package)
if source:
try:
tree = ast.parse(source)
except SyntaxError:
return frozenset()
else:
visitor = cls()
set_tree_parents(tree)
visitor.visit(tree)
return visitor.get_all() or visitor.get_suggestion()
return frozenset()

def get_all(self) -> frozenset[str]:
names = set()

class SuggestionNameAnalyzer(ast.NodeVisitor):
__slots__ = ("suggestions_nodes",)

def __init__(self) -> None:
self.suggestions_nodes: list[T.ASTImportableT] = [] # nodes on the CFN

def traverse(self, tree):
self.visit(tree)

@generic_visit
def visit_def(self, node: T.CFNT) -> None:
if not first_parent_match(node, C.DEF_TUPLE):
self.suggestions_nodes.append(node)

visit_ClassDef = visit_FunctionDef = visit_AsyncFunctionDef = visit_def

@generic_visit
def visit_Import(self, node: ast.Import) -> None:
for alias in node.names:
self.suggestions_nodes.append(alias)

@generic_visit
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if not node.names[0].name == "*":
for alias in node.names:
self.suggestions_nodes.append(alias)

@generic_visit
def visit_Assign(self, node: ast.Assign) -> None:
for target in node.targets: # we only get assigned names
if isinstance(target, (ast.Name, ast.Attribute)):
self.suggestions_nodes.append(target)


class ImportableNameWithScopeAnalyzer(ImportableNameAnalyzer):
def traverse(self, tree):
super().traverse(tree)

for node in self.importable_nodes:
names.add(node.value)
return frozenset(names)

def get_suggestion(self) -> frozenset[str]:
names = set()
for node in self.suggestions_nodes: # type: ignore
if isinstance(node, ast.Name):
names.add(node.id)
elif isinstance(node, ast.alias):
names.add(node.asname or node.name)
elif isinstance(node, C.DEF_TUPLE):
names.add(node.name)
return frozenset(names)

def clear(self):
self.importable_nodes.clear()
self.suggestions_nodes.clear()
Name.register(lineno=node.lineno, name=node.value, node=node, is_all=True)

def visit_def(self, node: T.CFNT) -> None:
Scope.add_current_scope(node)

self.generic_visit(node)

Scope.remove_current_scope()

visit_ClassDef = visit_FunctionDef = visit_AsyncFunctionDef = visit_def
4 changes: 2 additions & 2 deletions src/unimport/analyzers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path

from unimport.analyzers.import_statement import ImportAnalyzer
from unimport.analyzers.importable import ImportableAnalyzer
from unimport.analyzers.importable import ImportableNameWithScopeAnalyzer
from unimport.analyzers.name import NameAnalyzer
from unimport.analyzers.utils import get_defined_names, set_tree_parents
from unimport.statement import Import, ImportFrom, Name, Scope
Expand Down Expand Up @@ -38,7 +38,7 @@ def traverse(self) -> None:

NameAnalyzer().traverse(tree) # name analyzers

ImportableAnalyzer().traverse(tree) # importable analyzers for collect in __all__
ImportableNameWithScopeAnalyzer().traverse(tree) # importable analyzers for collect in __all__

ImportAnalyzer( # import analyzers
source=self.source, include_star_import=self.include_star_import, defined_names=get_defined_names(tree)
Expand Down
2 changes: 1 addition & 1 deletion src/unimport/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@
}
)

BUILTIN_MODULE_NAMES = frozenset(sys.builtin_module_names)
BUILTIN_MODULE_NAMES = sys.builtin_module_names
STDLIB_PATH = sysconfig.get_paths()["stdlib"]
3 changes: 1 addition & 2 deletions src/unimport/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import ast
import dataclasses
import operator
import typing

__all__ = ("Import", "ImportFrom", "Name", "Scope")
Expand All @@ -20,7 +19,7 @@ class Import:
node: ast.Import | ast.ImportFrom = dataclasses.field(init=False, repr=False, compare=False)

def __len__(self) -> int:
return operator.length_hint(self.name.split("."))
return len(self.name.split("."))

def is_match_sub_packages(self, name_name: str) -> bool:
return self.name.split(".")[0] == name_name
Expand Down
2 changes: 1 addition & 1 deletion src/unimport/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"CSTImportT",
)

ASTImportableT = typing.Union[ast.AsyncFunctionDef, ast.Attribute, ast.ClassDef, ast.FunctionDef, ast.Name, ast.alias]
ASTImportableT = typing.Union[ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef, ast.Attribute, ast.Name, ast.alias]
ASTFunctionT = typing.TypeVar("ASTFunctionT", ast.FunctionDef, ast.AsyncFunctionDef)

CFNT = typing.TypeVar("CFNT", ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef, ast.Name)
Expand Down
Loading

0 comments on commit 563fb26

Please sign in to comment.