From e52823b7689da323530b21fb820700022c5fc608 Mon Sep 17 00:00:00 2001 From: Bernat Gabor Date: Wed, 20 Feb 2019 12:15:04 +0000 Subject: [PATCH] First implementation Signed-off-by: Bernat Gabor --- mypy/build.py | 32 +- mypy/find_sources.py | 119 ++++--- mypy/main.py | 4 + mypy/modulefinder.py | 31 +- mypy/options.py | 2 + mypy/stub_src_merge.py | 404 +++++++++++++++++++++++ mypy/test/test_merge_stub_src.py | 194 +++++++++++ mypy/test/test_source_finder.py | 145 ++++++++ test-data/unit/check-stub-src-merge.test | 333 +++++++++++++++++++ 9 files changed, 1221 insertions(+), 43 deletions(-) create mode 100644 mypy/stub_src_merge.py create mode 100644 mypy/test/test_merge_stub_src.py create mode 100644 mypy/test/test_source_finder.py create mode 100644 test-data/unit/check-stub-src-merge.test diff --git a/mypy/build.py b/mypy/build.py index bdc942c5e2bea..66f1c088480cc 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -52,6 +52,7 @@ from mypy.options import Options from mypy.parse import parse from mypy.stats import dump_type_stats +from mypy.stub_src_merge import MergeFiles from mypy.types import Type from mypy.version import __version__ from mypy.plugin import Plugin, ChainedPlugin, plugin_types @@ -107,6 +108,8 @@ def __init__(self, sources: List[BuildSource]) -> None: self.source_text_present = True elif source.path: self.source_paths.add(source.path) + if source.merge_with and source.merge_with.path: + self.source_paths.add(source.merge_with.path) else: self.source_modules.add(source.module) @@ -2154,6 +2157,23 @@ def generate_unused_ignore_notes(self) -> None: self.verify_dependencies(suppressed_only=True) self.manager.errors.generate_unused_ignore_notes(self.xpath) + def merge_with(self, state: 'State', errors: Errors) -> None: + self.ancestors = list(set(self.ancestors or []) | set(state.ancestors or [])) + self.child_modules = set(self.child_modules) | set(state.child_modules) + self.dependencies = list(set(self.dependencies) | set(state.dependencies)) + + dep_line_map = {k: -v for k, v in state.dep_line_map.items()} + dep_line_map.update(self.dep_line_map) + self.dep_line_map = dep_line_map + + priorities = {k: -v for k, v in state.priorities.items()} + priorities.update(self.priorities) + self.priorities = priorities + + self.dep_line_map.update({k: -v for k, v in state.dep_line_map.items()}) + if self.tree is not None and state.tree is not None: + MergeFiles(self.tree, state.tree, errors, self.xpath, self.id).run() + # Module import and diagnostic glue @@ -2199,6 +2219,8 @@ def find_module_and_diagnose(manager: BuildManager, file_id = '__builtin__' path = find_module_simple(file_id, manager) if path: + if isinstance(path, tuple): + path = path[-1] # For non-stubs, look at options.follow_imports: # - normal (default) -> fully analyze # - silent -> analyze but silence errors @@ -2274,7 +2296,7 @@ def exist_added_packages(suppressed: List[str], return False -def find_module_simple(id: str, manager: BuildManager) -> Optional[str]: +def find_module_simple(id: str, manager: BuildManager) -> Union[Optional[str], Tuple[str, ...]]: """Find a filesystem path for module `id` or `None` if not found.""" t0 = time.time() x = manager.find_module_cache.find_module(id) @@ -2535,8 +2557,16 @@ def load_graph(sources: List[BuildSource], manager: BuildManager, # Seed the graph with the initial root sources. 
for bs in sources: try: + stub_state = None + if bs.merge_with: + mw = bs.merge_with + stub_state = State(id=mw.module, path=mw.path, source=mw.text, + manager=manager, root_source=True) + stub_state.parse_file() st = State(id=bs.module, path=bs.path, source=bs.text, manager=manager, root_source=True) + if stub_state is not None: + st.merge_with(stub_state, manager.errors) except ModuleNotFound: continue if st.id in graph: diff --git a/mypy/find_sources.py b/mypy/find_sources.py index 034302cc85d9e..c96cad3000495 100644 --- a/mypy/find_sources.py +++ b/mypy/find_sources.py @@ -1,8 +1,9 @@ """Routines for finding the sources that mypy will check""" +from itertools import tee, filterfalse import os.path -from typing import List, Sequence, Set, Tuple, Optional, Dict +from typing import List, Sequence, Set, Tuple, Optional, Dict, Callable, Iterable from mypy.modulefinder import BuildSource, PYTHON_EXTENSIONS from mypy.fscache import FileSystemCache @@ -19,27 +20,49 @@ class InvalidSourceList(Exception): """Exception indicating a problem in the list of sources given to mypy.""" -def create_source_list(files: Sequence[str], options: Options, - fscache: Optional[FileSystemCache] = None, - allow_empty_dir: bool = False) -> List[BuildSource]: +def partition( + pred: Callable[[str], bool], iterable: Iterable[str] +) -> Tuple[Iterable[str], Iterable[str]]: + """Use a predicate to partition entries into false entries and true entries""" + t1, t2 = tee(iterable) + return filterfalse(pred, t1), filter(pred, t2) + + +def create_source_list( + files: Sequence[str], + options: Options, + fscache: Optional[FileSystemCache] = None, + allow_empty_dir: bool = False, +) -> List[BuildSource]: """From a list of source files/directories, makes a list of BuildSources. Raises InvalidSourceList on errors. """ fscache = fscache or FileSystemCache() - finder = SourceFinder(fscache) - + finder = SourceFinder(fscache, options.merge_stub_into_src) targets = [] - for f in files: - if f.endswith(PY_EXTENSIONS): + found_targets = set() # type: Set[str] + other, stubs = partition(lambda v: v.endswith(".pyi"), files) + source_then_stubs = list(other) + list(stubs) + for f in source_then_stubs: + if f in found_targets: + continue + base, ext = os.path.splitext(f) + found_targets.add(f) + if ext in PY_EXTENSIONS: # Can raise InvalidSourceList if a directory doesn't have a valid module name. name, base_dir = finder.crawl_up(os.path.normpath(f)) - targets.append(BuildSource(f, name, None, base_dir)) + merge_stub = None # type: Optional[BuildSource] + if options.merge_stub_into_src and ext == ".py": + stub_file = "{}.pyi".format(base) + if os.path.exists(stub_file): + merge_stub = BuildSource(stub_file, name, None, base_dir) + found_targets.add(stub_file) + targets.append(BuildSource(f, name, None, base_dir, merge_with=merge_stub)) elif fscache.isdir(f): sub_targets = finder.expand_dir(os.path.normpath(f)) if not sub_targets and not allow_empty_dir: - raise InvalidSourceList("There are no .py[i] files in directory '{}'" - .format(f)) + raise InvalidSourceList("There are no .py[i] files in directory '{}'".format(f)) targets.extend(sub_targets) else: mod = os.path.basename(f) if options.scripts_are_modules else None @@ -47,25 +70,26 @@ def create_source_list(files: Sequence[str], options: Options, return targets -def keyfunc(name: str) -> Tuple[int, str]: +PY_MAP = {k: i for i, k in enumerate(PY_EXTENSIONS)} + + +def keyfunc(name: str) -> Tuple[str, int]: """Determines sort order for directory listing. 
The desirable property is foo < foo.pyi < foo.py. """ base, suffix = os.path.splitext(name) - for i, ext in enumerate(PY_EXTENSIONS): - if suffix == ext: - return (i, base) - return (-1, name) + return base, PY_MAP.get(suffix, -1) class SourceFinder: - def __init__(self, fscache: FileSystemCache) -> None: + def __init__(self, fscache: FileSystemCache, merge_stub_into_src: bool) -> None: self.fscache = fscache # A cache for package names, mapping from directory path to module id and base dir self.package_cache = {} # type: Dict[str, Tuple[str, str]] + self.merge_stub_into_src = merge_stub_into_src - def expand_dir(self, arg: str, mod_prefix: str = '') -> List[BuildSource]: + def expand_dir(self, arg: str, mod_prefix: str = "") -> List[BuildSource]: """Convert a directory name to a list of sources to build.""" f = self.get_init_file(arg) if mod_prefix and not f: @@ -79,27 +103,46 @@ def expand_dir(self, arg: str, mod_prefix: str = '') -> List[BuildSource]: sources.append(BuildSource(f, mod_prefix.rstrip('.'), None, base_dir)) names = self.fscache.listdir(arg) names.sort(key=keyfunc) - for name in names: - # Skip certain names altogether - if (name == '__pycache__' or name == 'py.typed' - or name.startswith('.') - or name.endswith(('~', '.pyc', '.pyo'))): - continue - path = os.path.join(arg, name) - if self.fscache.isdir(path): - sub_sources = self.expand_dir(path, mod_prefix + name + '.') - if sub_sources: - seen.add(name) - sources.extend(sub_sources) - else: - base, suffix = os.path.splitext(name) - if base == '__init__': + name_iter = iter(names) + try: + name = next(name_iter, None) + while name is not None: + # Skip certain names altogether + if (name == '__pycache__' or name == 'py.typed' + or name.startswith('.') + or name.endswith(('~', '.pyc', '.pyo'))): continue - if base not in seen and '.' not in base and suffix in PY_EXTENSIONS: - seen.add(base) - src = BuildSource(path, mod_prefix + base, None, base_dir) - sources.append(src) - return sources + path = os.path.join(arg, name) + if self.fscache.isdir(path): + sub_sources = self.expand_dir(path, mod_prefix + name + '.') + if sub_sources: + seen.add(name) + sources.extend(sub_sources) + name = next(name_iter) + else: + base, suffix = os.path.splitext(name) + name = next(name_iter, None) + if base == '__init__': + continue + if base not in seen and '.' 
not in base and suffix in PY_EXTENSIONS: + seen.add(base) + if name is None: + next_base, next_suffix = None, None + else: + next_base, next_suffix = os.path.splitext(name) + src = BuildSource(path, mod_prefix + base, None, base_dir) + if self.merge_stub_into_src is True and next_base is not None \ + and next_base == base and name is not None: + merge_with = src + src = BuildSource(path=os.path.join(arg, name), + module=mod_prefix + next_base, + merge_with=merge_with, + text=None, + base_dir=base_dir) + sources.append(src) + return sources + except StopIteration: + return sources def crawl_up(self, arg: str) -> Tuple[str, str]: """Given a .py[i] filename, return module and base directory diff --git a/mypy/main.py b/mypy/main.py index ab23f90c7fde1..f66d85f7ee442 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -631,6 +631,10 @@ def add_invertible_flag(flag: str, '--find-occurrences', metavar='CLASS.MEMBER', dest='special-opts:find_occurrences', help="Print out all usages of a class member (experimental)") + add_invertible_flag('--merge-stub-into-src', default=False, strict_flag=False, + help="when a stub and source file is in the same folder with same name " + "merge the stub file into the source file, and lint the source file", + group=other_group) if server_options: # TODO: This flag is superfluous; remove after a short transition (2018-03-16) diff --git a/mypy/modulefinder.py b/mypy/modulefinder.py index 54caf7da3d149..73d023f7d6a21 100644 --- a/mypy/modulefinder.py +++ b/mypy/modulefinder.py @@ -10,7 +10,7 @@ import subprocess import sys -from typing import Dict, List, NamedTuple, Optional, Set, Tuple +from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union MYPY = False if MYPY: @@ -41,17 +41,32 @@ class BuildSource: """A single source file.""" def __init__(self, path: Optional[str], module: Optional[str], - text: Optional[str], base_dir: Optional[str] = None) -> None: + text: Optional[str], base_dir: Optional[str] = None, + merge_with: Optional['BuildSource'] = None) -> None: self.path = path # File where it's found (e.g. 'xxx/yyy/foo/bar.py') self.module = module or '__main__' # Module name (e.g. 'foo.bar') self.text = text # Source code, if initially supplied, else None self.base_dir = base_dir # Directory where the package is rooted (e.g. 'xxx/yyy') + self.merge_with = merge_with def __repr__(self) -> str: return '' % (self.path, self.module, self.text is not None) + def __eq__(self, other: Any) -> bool: + if not isinstance(other, BuildSource): + return False + return (self.path, self.module, self.text, self.base_dir, self.merge_with) == ( + other.path, + other.module, + other.text, + other.base_dir, + other.merge_with, + ) + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) class FindModuleCache: """Module finder with integrated cache. @@ -232,13 +247,18 @@ def _find_module(self, id: str) -> Optional[str]: elif self.options and self.options.namespace_packages and fscache.isdir(base_path): near_misses.append(base_path) # No package, look for module. + paths = [] for extension in PYTHON_EXTENSIONS: path = base_path + extension if fscache.isfile_case(path): if verify and not verify_module(fscache, id, path): near_misses.append(path) continue - return path + paths.append(path) + if len(paths) == 1: + return paths[0] + elif len(paths) > 1: + return tuple(paths) # In namespace mode, re-check those entries that had 'verify'. 
# Assume search path entries xxx, yyy and zzz, and we're @@ -276,7 +296,10 @@ def find_modules_recursive(self, module: str) -> List[BuildSource]: module_path = self.find_module(module) if not module_path: return [] - result = [BuildSource(module_path, module, None)] + merge_with = None + if isinstance(module_path, tuple): + module_path = module_path + result = [BuildSource(module_path, module, None, merge_with=merge_with)] if module_path.endswith(('__init__.py', '__init__.pyi')): # Subtle: this code prefers the .pyi over the .py if both # exists, and also prefers packages over modules if both x/ diff --git a/mypy/options.py b/mypy/options.py index c498ad6fba504..2de8246a96df1 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -241,6 +241,8 @@ def __init__(self) -> None: # Don't properly free objects on exit, just kill the current process. self.fast_exit = False + self.merge_stub_into_src = False + def snapshot(self) -> object: """Produce a comparable snapshot of this Option""" # Under mypyc, we don't have a __dict__, so we need to do worse things. diff --git a/mypy/stub_src_merge.py b/mypy/stub_src_merge.py new file mode 100644 index 0000000000000..6c0e44b8fd4f3 --- /dev/null +++ b/mypy/stub_src_merge.py @@ -0,0 +1,404 @@ +import itertools + +from contextlib import contextmanager +from typing import ( + Any, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + Union, + cast, +) + +from mypy.errors import Errors +from mypy.nodes import ( + AssignmentStmt, + CallExpr, + ClassDef, + Decorator, + EllipsisExpr, + FuncDef, + IfStmt, + ImportBase, + MemberExpr, + MypyFile, + NameExpr, + Node, + Statement, + SymbolTable, + TempNode, + TypeInfo, + Var, +) + +MYPY = False +if MYPY: + from typing import Type + + +class MergeFiles: + """this is a class that allows merge-ing a stub MypyFile into a source MypyFile""" + + def __init__( + self, src: MypyFile, stub: MypyFile, errors: Errors, src_xpath: str, src_id: str + ) -> None: + """ create the class that can perfrom the merge operation + + :param src: the source MypyFile + :param stub: the stub MypyFile + :param errors: the place to reports errors during the merge + :param src_xpath: reports errors for this file (should be the source file) + :param src_id: reports errors for this module (should be the source module) + """ + self.src = src # type: MypyFile + self.stub = stub # type: MypyFile + self.errors = errors # type: Errors + self.errors.set_file(src_xpath, src_id) + self._line = -1 # type: int + self._column = 0 # type: int + + def run(self) -> None: + """performs the merge operation""" + self.merge_symbol_tables(self.src.names, self.stub.names) + self.enrich_src_defs_with_stub_only() + + def report( + self, msg: str, line: Optional[int] = None, column: Optional[int] = None + ) -> None: + """report an error + + :param msg: the message to report + :param line: the line number where the message is at (by default use last \ + known position during the merge) + :param column: the column where the message is at (by default use last known) + """ + if line is None: + line = self._line + if column is None: + column = self._column + self.errors.report(line=line, message=msg, column=column) + + def merge_symbol_tables(self, src: SymbolTable, stub: SymbolTable) -> None: + """merge every element from the stub into the source file""" + for stub_symbol_name, stub_symbol in stub.items(): + if stub_symbol_name not in src: + src[stub_symbol_name] = stub_symbol + else: + src_symbol = src[stub_symbol_name] + if src_symbol is not stub_symbol: + 
self.merge_nodes(src_symbol.node, stub_symbol.node) + + @contextmanager + def report_from_element(self, src: Any) -> Generator[None, None, None]: + """changes the current report location to the given source location""" + old_line, old_column = self._line, self._column + line = getattr(src, "line", old_line) + column = getattr(src, "line", old_column) + self._line, self._column = line, column + try: + yield + finally: + self._line, self._column = old_line, old_column + + def merge_nodes(self, src: Any, stub: Any) -> None: + """merge a given stub symbol into a source symbol""" + with self.report_from_element(src): + node_type = type(src) + if node_type != type(stub): + self.report( + "conflict of src {} and stub:{} {} definition".format( + src.__class__.__qualname__, + stub.line, + stub.__class__.__qualname__, + ) + ) + elif isinstance(src, Var): + self.merge_variables(src, stub) + elif isinstance(src, AssignmentStmt): + self.merge_assignment(src, stub) + elif isinstance(src, FuncDef): + self.merge_func_definition(src, stub) + elif isinstance(src, TypeInfo): + self.merge_type_info(src, stub) + elif isinstance(src, Decorator): + self.merge_decorators(src, stub) + else: + raise RuntimeError( + "cannot merge {!r} with {!r}".format(src, stub) + ) # pragma: no cover + + @staticmethod + def merge_variables(src: Var, stub: Var) -> None: + """merge a variable definition""" + src.type = stub.type + + def merge_assignment(self, src: AssignmentStmt, stub: AssignmentStmt) -> None: + """merge class variables""" + # note here we don't need to check name matching as the class merge logic + # already made sure they match + src.type = stub.type + src.unanalyzed_type = stub.unanalyzed_type + stub_name = cast(str, self.simple_assignment_name(stub)) + self.check_no_explicit_default(stub.rvalue, stub, stub_name) + + def merge_func_definition(self, src: FuncDef, stub: FuncDef) -> None: + """merge a function definition""" + if src.arg_names == stub.arg_names: + src.type = stub.type + src.arg_kinds = stub.arg_kinds + src.unanalyzed_type = stub.unanalyzed_type + else: + self.report( + "arg conflict of src {} and stub (line {}) {}".format( + repr(src.arg_names), stub.line, repr(stub.arg_names) + ) + ) + + def merge_type_info(self, src: TypeInfo, stub: TypeInfo) -> None: + """merge a class definition""" + src.type_vars = stub.type_vars + src.metaclass_type = stub.metaclass_type + src.runtime_protocol = stub.runtime_protocol + self.merge_symbol_tables(src.names, stub.names) + src.defn.type_vars = stub.defn.type_vars + + stub_assigns, stub_funcs = self.collect_assign_and_funcs(stub.defn.defs.body) + # merge stub entries into source + for entry in src.defn.defs.body: + if isinstance(entry, FuncDef): + self.merge_class_func_def(entry, stub_assigns, stub_funcs) + elif isinstance(entry, AssignmentStmt): + self.merge_class_assignment(entry, stub_assigns) + elif isinstance(entry, Decorator): + self.merge_class_decorator(entry, stub_funcs) + + # report extra stub entries + for k, v in stub_assigns.items(): + self.report("no source for assign {} @stub:{}".format(k, v.line)) + for f_k, f_v in stub_funcs.items(): + self.report("no source for func {} @stub:{}".format(f_k, f_v.line)) + + def merge_class_decorator( + self, src: Decorator, stub_funcs: Dict[str, Union[Decorator, FuncDef]] + ) -> None: + """merge class decorator, if stub not found""" + name = src.func.name() + if name in stub_funcs: + self.merge_nodes(src, stub_funcs[name]) + del stub_funcs[name] + + def merge_class_assignment( + self, src: AssignmentStmt, stubs: 
Dict[str, AssignmentStmt] + ) -> None: + """merge class variables""" + src_name = self.simple_assignment_name(src) + if src_name is not None: + if src_name in stubs: + self.merge_nodes(src, stubs[src_name]) + del stubs[src_name] + + def simple_assignment_name( + self, + node: AssignmentStmt, + l_value_type: 'Type[Union[NameExpr, MemberExpr]]' = NameExpr, + report: bool = True, + ) -> Optional[str]: + if len(node.lvalues) == 1: + l_value = node.lvalues[0] + if isinstance(l_value, l_value_type): + return cast(Union[NameExpr, MemberExpr], l_value).name + else: + if report: + self.report( + "l-values must be simple name expressions, is {}".format( + type(l_value).__qualname__ + ) + ) + elif report: # pragma: no cover + # TODO: how can we have more than one l-value in an assignment + self.report("assignment has more than one l-values") # pragma: no cover + return None + + def merge_class_func_def( + self, + src: FuncDef, + stub_assigns: Dict[str, AssignmentStmt], + stub_funcs: Dict[str, Union[Decorator, FuncDef]], + ) -> None: + """merge the function nodes and if it's a class constructor try to enrich + self assignments from the class variable type hints + """ + name = src.name() + if name in stub_funcs: + self.merge_nodes(src, stub_funcs[name]) + del stub_funcs[name] + if name == "__init__": + for init_part in src.body.body: + # we support either direct assignment, or assignment within + # if statements + if isinstance(init_part, AssignmentStmt): + self.enrich_src_assign_from_stub(init_part, stub_assigns) + elif isinstance(init_part, IfStmt): + for b in init_part.body: + for part in b.body: + if isinstance(part, AssignmentStmt): + self.enrich_src_assign_from_stub(part, stub_assigns) + + def collect_assign_and_funcs( + self, body: Iterable[Node] + ) -> Tuple[Dict[str, AssignmentStmt], Dict[str, Union[Decorator, FuncDef]]]: + """collect assignments and stubs from body""" + funcs = {} # type: Dict[str, Union[Decorator, FuncDef]] + assigns = {} # type: Dict[str, AssignmentStmt] + for entry in body: + if isinstance(entry, AssignmentStmt): + name = self.simple_assignment_name(entry) + if name is not None: + assigns[name] = entry + elif isinstance(entry, Decorator): + funcs[entry.func.name()] = entry + elif isinstance(entry, FuncDef): + funcs[entry.name()] = entry + return assigns, funcs + + def merge_decorators(self, src: Decorator, stub: Decorator) -> None: + """merge decorators above functions""" + for l, r in itertools.zip_longest(src.decorators, stub.decorators): + if type(l) != type(r): + self.report( + "conflict of src {} and stub {} decorator".format( + l.__class__.__qualname__, r.__class__.__qualname__ + ) + ) + break + if isinstance(l, NameExpr): + if not self.decorator_name_check(l, r): + break + elif isinstance(l, CallExpr): + if isinstance(l.callee, NameExpr) and isinstance(r.callee, NameExpr): + if not self.decorator_name_check(l.callee, r.callee): + break + self.decorator_argument_checks(l, r) + else: + self.merge_nodes(src.func, stub.func) + + def decorator_argument_checks(self, src: CallExpr, stub: CallExpr) -> None: + """check decorator arguments""" + for l_arg, r_arg in itertools.zip_longest(src.arg_names, stub.arg_names): + if l_arg != r_arg: + self.report( + "conflict of src {} and stub {} decorator argument name".format( + l_arg, r_arg + ), + line=src.line, + ) + break + for name, default_node in zip(stub.arg_names, stub.args): + self.check_no_explicit_default(default_node, src, cast(str, name)) + + def check_no_explicit_default( + self, default_node: Node, node: Node, name: str 
+ ) -> None: + """check that no default value is set for this node""" + if not isinstance(node, TempNode) and not isinstance( + default_node, EllipsisExpr + ): + self.report( + ( + "stub should not contain default value, {} has {}".format( + name, type(default_node).__name__ + ) + ) + ) + + def decorator_name_check(self, src: NameExpr, stub: NameExpr) -> bool: + """check if the decorator name from source and stub match + + :return: True if the names match + """ + if src.name != stub.name: + self.report( + "conflict of src {} and stub {} decorator name".format( + src.name, stub.name + ) + ) + return False + return True + + def enrich_src_assign_from_stub( + self, src_assign: AssignmentStmt, stub_assign: Dict[str, AssignmentStmt] + ) -> bool: + """try to match a source assignment against existing stub assignment + + :return: True, if we found a matching stub assignment + """ + name = self.simple_assignment_name( + src_assign, l_value_type=MemberExpr, report=False + ) + if name is not None: + if name in stub_assign: + src_assign.type = stub_assign[name].type + src_assign.unanalyzed_type = stub_assign[name].unanalyzed_type + del stub_assign[name] + return True + else: + self.report("no stub definition for class member {}".format(name)) + return False + + def enrich_src_defs_with_stub_only(self) -> None: + """ + There are definitions that are needed to evaluate the source, which are not + present in the source file, only the stub: + - imports from the stub file (this help resolve those types) + - type aliases (take form of assignment) + - protocol definitions (in form of a class definition) + + Here we copy them over into the source definitions. + """ + src_definitions = self.source_definitions + + stub_definitions = [] # type: List[Statement] + for i in self.stub.defs: + if isinstance(i, ImportBase): + stub_definitions.append(i) + elif isinstance(i, ClassDef): + if i.name not in src_definitions: + stub_definitions.append(i) + elif isinstance(i, AssignmentStmt) and len(i.lvalues) == 1: + entry = i.lvalues[0] + if isinstance(entry, NameExpr): + name = entry.name + if name not in src_definitions: + # could be a type alias we add this + stub_definitions.append(i) + else: + # merge it into the source data + src_assignment = src_definitions[name] + if isinstance(src_assignment, AssignmentStmt): + src_assignment.type = i.type + src_assignment.unanalyzed_type = i.unanalyzed_type + del src_definitions[name] + + # stub imports are available for source + self.src.imports.extend(self.stub.imports) + + # we insert at start the stub definitions, this is important so source + # definition evaluation have all stub definitions available + self.src.defs = stub_definitions + self.src.defs + + @property + def source_definitions(self) -> Dict[str, Union[ClassDef, AssignmentStmt]]: + """collect source definitions that influence the symbol table (without imports) + """ + src_definitions = {} # type: Dict[str, Union[ClassDef, AssignmentStmt]] + for d in self.src.defs: + if isinstance(d, ClassDef): + src_definitions[d.name] = d + elif isinstance(d, AssignmentStmt): + for l in d.lvalues: + if isinstance(l, NameExpr): + src_definitions[l.name] = d + return src_definitions diff --git a/mypy/test/test_merge_stub_src.py b/mypy/test/test_merge_stub_src.py new file mode 100644 index 0000000000000..53a632e22608f --- /dev/null +++ b/mypy/test/test_merge_stub_src.py @@ -0,0 +1,194 @@ +"""Type checker test cases""" + +import os +import re +import sys +from typing import Dict, List, Set, Tuple + +from mypy import build +from 
mypy.build import Graph +from mypy.errors import CompileError +from mypy.modulefinder import BuildSource, SearchPaths, FindModuleCache +from mypy.test.config import test_temp_dir, test_data_prefix +from mypy.test.data import DataDrivenTestCase, DataSuite +from mypy.test.helpers import ( + assert_string_arrays_equal, + normalize_error_messages, + update_testcase_output, + parse_options, +) + +# List of files that contain test case descriptions. +typecheck_files = ["check-stub-src-merge.test"] + + +class TypeCheckSuite(DataSuite): + files = typecheck_files + + def run_case(self, testcase: DataDrivenTestCase) -> None: + self.run_case_once(testcase) + + def run_case_once(self, testcase: DataDrivenTestCase) -> None: + original_program_text = "\n".join(testcase.input) + options = parse_options(original_program_text, testcase, 0) + + if any(p.startswith(os.path.join("tmp", "out")) for p, _ in testcase.files): + direct_asts = self.load_data( + options, original_program_text, testcase, to_merge=False + ) + else: + direct_asts = None + merged_asts = self.load_data( + options, original_program_text, testcase, to_merge=True + ) + if direct_asts is not None: + assert merged_asts.keys() == direct_asts.keys() + for key in merged_asts: + merged = merged_asts[key] + direct = direct_asts[key] + assert merged == direct + + def load_data(self, options, original_program_text, testcase, to_merge=False): + folder = "in" if to_merge else "out" + module_data = self.parse_module(original_program_text, folder) + options.merge_stub_into_src = to_merge + options.use_builtins_fixtures = True + options.show_traceback = True + options.strict_optional = True + sources = [] + iterator = iter(module_data) + no_next = None, None, None + module_name, program_path, program_text = next(iterator, no_next) + while module_name: + source = BuildSource(program_path, module_name, program_text) + sources.append(source) + module_name, program_path, program_text = next(iterator, no_next) + if options.merge_stub_into_src is True: + if source.path.endswith(".pyi") and source.module == module_name: + src = BuildSource( + program_path, module_name, program_text, merge_with=source + ) + sources[-1] = src + module_name, program_path, program_text = next(iterator, no_next) + plugin_dir = os.path.join(test_data_prefix, "plugins") + sys.path.insert(0, plugin_dir) + res = None + try: + res = build.build( + sources=sources, + options=options, + alt_lib_path=os.path.join(test_temp_dir), + ) + a = res.errors + except CompileError as e: + a = e.messages + finally: + assert sys.path[0] == plugin_dir + del sys.path[0] + if to_merge: + if testcase.normalize_output: + a = normalize_error_messages(a) + msg = "Unexpected type checker output ({}, line {})" + output = testcase.output + if output != a and testcase.config.getoption("--update-data", False): + update_testcase_output(testcase, a) + assert_string_arrays_equal( + output, a, msg.format(testcase.file, testcase.line) + ) + if res: + if options.cache_dir != os.devnull: + self.verify_cache(module_data, res.errors, res.manager, res.graph) + ast_mod_to_graph = {} + for source in sources: + full_ast_str = str(res.graph[source.module].tree) + repr_path = source.path.replace(os.sep, "/") + file_path_regex = re.compile( + r"\s+{}\s+^".format(re.escape(repr_path)), re.MULTILINE + ) + ast_with_no_file_path = file_path_regex.sub("\n", full_ast_str) + line_nr = re.compile(r"(\w+):\d+") + ast_str = line_nr.sub(r"\1", ast_with_no_file_path) + ast_mod_to_graph[source.module] = ast_str + return ast_mod_to_graph + + 
def verify_cache( + self, + module_data: List[Tuple[str, str, str]], + a: List[str], + manager: build.BuildManager, + graph: Graph, + ) -> None: + # There should be valid cache metadata for each module except + # for those that had an error in themselves or one of their + # dependencies. + error_paths = self.find_error_message_paths(a) + busted_paths = { + m.path for id, m in manager.modules.items() if graph[id].transitive_error + } + modules = self.find_module_files(manager) + modules.update({module_name: path for module_name, path, text in module_data}) + missing_paths = self.find_missing_cache_files(modules, manager) + # We would like to assert error_paths.issubset(busted_paths) + # but this runs into trouble because while some 'notes' are + # really errors that cause an error to be marked, many are + # just notes attached to other errors. + assert ( + error_paths or not busted_paths + ), "Some modules reported error despite no errors" + if not missing_paths == busted_paths: + raise AssertionError( + "cache data discrepancy %s != %s" % (missing_paths, busted_paths) + ) + + def find_error_message_paths(self, a: List[str]) -> Set[str]: + hits = set() + for line in a: + m = re.match(r"([^\s:]+):(\d+:)?(\d+:)? (error|warning|note):", line) + if m: + p = m.group(1) + hits.add(p) + return hits + + def find_module_files(self, manager: build.BuildManager) -> Dict[str, str]: + modules = {} + for id, module in manager.modules.items(): + modules[id] = module.path + return modules + + def find_missing_cache_files( + self, modules: Dict[str, str], manager: build.BuildManager + ) -> Set[str]: + ignore_errors = True + missing = {} + for id, path in modules.items(): + meta = build.find_cache_meta(id, path, manager) + if not build.validate_meta(meta, id, path, ignore_errors, manager): + missing[id] = path + return set(missing.values()) + + def parse_module(self, program_text: str, folder) -> List[Tuple[str, str, str]]: + """Return a list of tuples (module name, file name, program text). """ + m = re.search( + r"# modules: ([a-zA-Z0-9_. ]+)$", program_text, flags=re.MULTILINE + ) + if m: + in_path = os.path.join(test_temp_dir, folder) + # The test case wants to use a non-default main + # module. Look up the module and give it as the thing to + # analyze. 
+ module_names = m.group(1) + out = [] + search_paths = SearchPaths((in_path,), (), (), ()) + cache = FindModuleCache(search_paths) + for module_name in module_names.split(" "): + path = cache.find_module(module_name) + assert path is not None, "Can't find ad hoc case file {} in {}".format( + in_path, module_names + ) + for p in path if isinstance(path, tuple) else (path,): + with open(p, encoding="utf8") as f: + program_text = f.read() + out.append((module_name, p, program_text)) + return out + else: + raise ValueError("no modules defined") diff --git a/mypy/test/test_source_finder.py b/mypy/test/test_source_finder.py new file mode 100644 index 0000000000000..c6019dac8dbe4 --- /dev/null +++ b/mypy/test/test_source_finder.py @@ -0,0 +1,145 @@ +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import pytest # type: ignore # no pytest in typeshed +from py._path.local import LocalPath # type: ignore # no py in typeshed + +from mypy.build import BuildSource +from mypy.find_sources import create_source_list +from mypy.options import Options + +MergeFinder = Callable[ + [Dict[str, Any], bool, Optional[Sequence[str]]], Tuple[Path, List[BuildSource]] +] + + +def create_files(folder: Path, content: Dict[str, Any]) -> None: + if not folder.exists(): + folder.mkdir() + for key, value in content.items(): + if key.endswith(".py") or key.endswith(".pyi"): + with open(str(folder / key), "wt") as file_handler: + file_handler.write(value) + elif isinstance(value, dict): + create_files(folder / key, value) + + +@pytest.fixture() +def merge_finder( # type: ignore # cannot follow this import + tmpdir: LocalPath, monkeypatch: Any +) -> MergeFinder: + def _checker( + content: Dict[str, Any], merge: bool, args: Optional[Sequence[str]] = None + ) -> Tuple[Path, List[BuildSource]]: + monkeypatch.chdir(tmpdir) + options = Options() + options.merge_stub_into_src = merge + test_dir = str(tmpdir) + create_files(Path(test_dir), content) + targets = create_source_list(files=[test_dir] if args is None else args, options=options) + return Path(str(tmpdir)), targets + + return _checker + + +def test_source_finder_merge(merge_finder: MergeFinder) -> None: + base_dir, found = merge_finder({"a.py": "", "a.pyi": ""}, True, None) + assert found == [ + BuildSource( + path=str(base_dir / "a.py"), + base_dir=str(base_dir), + module="a", + text=None, + merge_with=BuildSource( + path=str(base_dir / "a.pyi"), base_dir=str(base_dir), module="a", text=None + ), + ) + ] + + +def test_source_finder_merge_sub_folder(merge_finder: MergeFinder) -> None: + base_dir, found = merge_finder( + {"pkg": {"a.py": "", "a.pyi": "", "__init__.py": ""}}, True, None + ) + assert found == [ + BuildSource( + path=str(base_dir / "pkg" / "__init__.py"), + base_dir=str(base_dir), + module="pkg", + text=None, + ), + BuildSource( + path=str(base_dir / "pkg" / "a.py"), + base_dir=str(base_dir), + module="pkg.a", + text=None, + merge_with=BuildSource( + path=str(base_dir / "pkg" / "a.pyi"), + base_dir=str(base_dir), + module="pkg.a", + text=None, + ), + ), + ] + + +def test_source_finder_no_merge(merge_finder: MergeFinder) -> None: + base_dir, found = merge_finder({"a.py": "", "a.pyi": ""}, False, None) + assert found == [ + BuildSource(path=str(base_dir / "a.pyi"), base_dir=str(base_dir), module="a", text=None) + ] + + +@pytest.mark.parametrize("merge", [True, False]) +def test_source_finder_merge_just_source(merge_finder: MergeFinder, merge: bool) -> None: + base_dir, found = merge_finder({"a.py": ""}, merge, 
None) + assert found == [ + BuildSource(path=str(base_dir / "a.py"), base_dir=str(base_dir), module="a", text=None) + ] + + +@pytest.mark.parametrize("merge", [True, False]) +def test_source_finder_merge_just_stub(merge_finder: MergeFinder, merge: bool) -> None: + base_dir, found = merge_finder({"a.pyi": ""}, merge, None) + assert found == [ + BuildSource(path=str(base_dir / "a.pyi"), base_dir=str(base_dir), module="a", text=None) + ] + + +def test_source_finder_matching_exists(merge_finder: MergeFinder) -> None: + base_dir, found = merge_finder({"a.py": "", "a.pyi": ""}, True, ["a.py"]) + assert found == [ + BuildSource( + path="a.py", + base_dir=".", + module="a", + text=None, + merge_with=BuildSource(path="a.pyi", base_dir=".", module="a", text=None), + ) + ] + + +def test_source_finder_matching_exists_stub_specified(merge_finder: MergeFinder) -> None: + base_dir, found = merge_finder({"a.py": "", "a.pyi": ""}, True, ["a.py", "a.pyi"]) + assert found == [ + BuildSource( + path="a.py", + base_dir=".", + module="a", + text=None, + merge_with=BuildSource(path="a.pyi", base_dir=".", module="a", text=None), + ) + ] + + +def test_source_finder_matching_exists_stub_specified_first(merge_finder: MergeFinder) -> None: + base_dir, found = merge_finder({"a.py": "", "a.pyi": ""}, True, ["a.pyi", "a.py"]) + assert found == [ + BuildSource( + path="a.py", + base_dir=".", + module="a", + text=None, + merge_with=BuildSource(path="a.pyi", base_dir=".", module="a", text=None), + ) + ] diff --git a/test-data/unit/check-stub-src-merge.test b/test-data/unit/check-stub-src-merge.test new file mode 100644 index 0000000000000..d90293ece25f4 --- /dev/null +++ b/test-data/unit/check-stub-src-merge.test @@ -0,0 +1,333 @@ +[case stub_src_type_does_not_match] +# modules: a +[file in/a.py] +a = 1 # E: conflict of src Var and stub:1 FuncDef definition +[file in/a.pyi] +def a() -> None: ... + +[case module_func_arg_conflict] +# modules: a +[file in/a.py] +def a(arg): ... # E: arg conflict of src ['arg'] and stub (line 1) ['arg1'] +[file in/a.pyi] +def a(arg1) -> None: ... + +[case module_func] +# modules: a +[file in/a.py] +def fancy_add(a, b = None): + pass +[file in/a.pyi] +from typing import Union +def fancy_add(a : int, b : Union[None, int] = ...) -> int: ... +[file out/a.py] +from typing import Union +def fancy_add(a : int, b : Union[None, int] = None) -> int: + pass + +[case module_func_reverse_union] +# modules: a +[file in/a.py] +def fancy_add(a, b = None): + pass +[file in/a.pyi] +from typing import Union +def fancy_add(a : int, b : Union[int, None] = ...) -> int: ... +[file out/a.py] +from typing import Union +def fancy_add(a : int, b : Union[int, None] = None) -> int: + pass + +[case module_variable] +# modules: a +[file in/a.py] +VERSION = ('1', '2', 3) +[file in/a.pyi] +from typing import Tuple +VERSION : Tuple[str, str, int] = ... +[file out/a.py] +from typing import Tuple +VERSION : Tuple[str, str, int] = ('1', '2', 3) + + +[case class_variable_default] +# modules: a +[file in/a.py] +class A: + b = 1 +[file in/a.pyi] +class A: + b : int = ... +[file out/a.py] +class A: + b : int = 1 + + +[case class_init_self_can_be_defined_at_lass_level] +# modules: a +[file in/a.py] +class A: + def __init__(self): + self.a = 's' +[file in/a.pyi] +class A: + a : str = ... + def __init__(self) -> None: ... 
+[file out/a.py] +class A: + def __init__(self) -> None: + self.a : str = 's' + + +[case class_init_not_defined] +# modules: a +[file in/a.py] +class A: # E: no stub definition for class member a + def __init__(self): + a = 1 + self.a = a +[file in/a.pyi] +class A: + def __init__(self) -> None: ... + +[case class_init_self_can_be_defined_at_class_level_inside_if] +# modules: a +[file in/a.py] +class A: + def __init__(self): + if True: + self.a = 0 + else: + self.a = 1 +[file in/a.pyi] +class A: + a : int = ... + def __init__(self) -> None: ... +[file out/a.py] +class A: + def __init__(self) -> None: + if True: + self.a : int = 0 + else: + self.a = 1 + +[case class_decorator] +# modules: a +[builtins fixtures/classmethod.pyi] +[file in/a.py] +class A: + @classmethod + def a(cls): + pass +[file in/a.pyi] +class A: + @classmethod + def a(cls) -> None: ... +[file out/a.py] +class A: + @classmethod + def a(cls) -> None: + pass + +[case class_decorator_no_source_decorator] +# modules: a +[builtins fixtures/classmethod.pyi] +[file in/a.py] +class A: # E: no source for func a @stub:2 + pass +[file in/a.pyi] +class A: + @classmethod + def a(cls) -> None: ... + +[case class_decorator_no_source_assignment] +# modules: a +[file in/a.py] +class A: # E: no source for assign a @stub:2 + pass +[file in/a.pyi] +class A: + a : int = ... + +[case module_variable_complex_not_supported] +# modules: a +[file in/a.py] +class A: # E: l-values must be simple name expressions, is TupleExpr + a, b = 1, 2 +[file in/a.pyi] +class A: + a, b = ..., ... + +[case class_decorator_no_source_func] +# modules: a +[file in/a.py] +class A: # E: no source for func a @stub:2 + pass +[file in/a.pyi] +class A: + def a(cls) -> None: ... + +[case variable_module_level] +# modules: a +[file in/a.py] +a = 1 +[file in/a.pyi] +a : int = ... +[file out/a.py] +a : int = 1 + +[case decorator_func] +# modules: a +[file in/a.py] +def d(p, r): + def t(func): + def w(b): + return str(func(int(b) + p + r)) + return w + return t +@d(p=1, r=2) +def a(b): + return b + 1 +a("1") +[file in/a.pyi] +from typing import Callable +def d(p: int, r: int) -> Callable[[Callable[[int], int]], Callable[[str], str]]: ... +@d(p= ..., r= ...) +def a(b: int) -> int: ... +[file out/a.py] +from typing import Callable +def d(p: int, r: int) -> Callable[[Callable[[int], int]], Callable[[str], str]]: + def t(func): + def w(b): + return str(func(int(b) + p + r)) + return w + return t +@d(p=1, r=2) +def a(b: int) -> int: + return b + 1 +a("1") + +[case decorator_func_arg_not_ellipse] +# modules: a +[file in/a.py] +def d(p, r): + def t(func): + def w(b): + return str(func(int(b) + p + r)) + return w + return t +@d(p=1, r=2) # E: stub should not contain default value, p has IntExpr # E: stub should not contain default value, r has NameExpr +def a(b): + return b + 1 +[file in/a.pyi] +from typing import Callable +def d(p: int, r: int) -> Callable[[Callable[[int], int]], Callable[[str], str]]: ... +@d(p= 1, r= None) +def a(b: int) -> int: ... + +[case decorator_func_args_keys_differ] +# modules: a +[file in/a.py] +def d(p, r): + def t(func): + def w(b): + return str(func(int(b) + p + r)) + return w + return t +@d(p=1, v=None) # E: Unexpected keyword argument "v" for "d" # E: conflict of src v and stub r decorator argument name +def a(): + pass +[file in/a.pyi] +@d(p=..., r=...) +def a() -> None: ... 
+ +[case decorator_name_does_not_match] +# modules: a +[file in/a.py] +def d(p, r): + def t(func): + def w(b): + return str(func(int(b) + p + r)) + return w + return t +def de(p, r): + return d(p, r) +@de(p=1, r=1) # E: conflict of src de and stub d decorator name +def a(b: int) -> int: ... +[file in/a.pyi] +from typing import Callable +def d(p: int, r: int) -> Callable[[Callable[[int], int]], Callable[[str], str]]: ... +def de(p: int, r: int) -> Callable[[Callable[[int], int]], Callable[[str], str]]: ... +@d(p=..., r=...) +def a(b: int) -> int: ... + +[case decorator_source_less] +# modules: a +[file in/a.py] +def d(func): + def w(): + return func() + return w +def de(func): + return d(func) +@d # E: conflict of src NoneType and stub NameExpr decorator +def a() -> None: + return None +[file in/a.pyi] +from typing import Callable +def d(func: Callable[[], None]) -> Callable[[], None]: ... +def de(func: Callable[[], None]) -> Callable[[], None]: ... +@d +@de +def a() -> None: ... + +[case decorator_stub_less] +# modules: a +[file in/a.py] +def d(func): + def w(): + return func() + return w +def de(func): + return d(func) +@d # E: conflict of src NameExpr and stub NoneType decorator +@de +def a() -> None: + return None +[file in/a.pyi] +from typing import Callable +def d(func: Callable[[], None]) -> Callable[[], None]: ... +def de(func: Callable[[], None]) -> Callable[[], None]: ... +@d +def a() -> None: ... + +[case decorator_expr_stub_not_ellipsis] +# modules: a +[file in/a.py] +def d(p, r): + def t(func): + def w(b): + return str(func(int(b) + p + r)) + return w + return t +@d(p=1, r=2) # E: stub should not contain default value, p has IntExpr +def a(b): + return b + 1 +[file in/a.pyi] +from typing import Callable +def d(p: int, r: int) -> Callable[[Callable[[int], int]], Callable[[str], str]]: ... +@d(p= 1, r= ...) +def a(b: int) -> int: ... + + +[case relative_import] +# modules: t2 p p.t +[file in/p/__init__.pyi] +[file in/p/t.py] +def f(x): pass +[file in/p/t.pyi] +def f(x: int) -> None: ... +[file in/t2.py] +from p.t import f + +
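
Example of the new source pairing (a minimal sketch, mirroring the cases added in
mypy/test/test_source_finder.py; the scratch directory and module name "a" are
illustrative assumptions, not part of the patch itself):

    import pathlib
    import tempfile

    from mypy.find_sources import create_source_list
    from mypy.options import Options

    # A module that ships both an implementation and a stub next to it.
    tmp = pathlib.Path(tempfile.mkdtemp())
    (tmp / "a.py").write_text("x = 1\n")
    (tmp / "a.pyi").write_text("x: int\n")

    options = Options()
    options.merge_stub_into_src = True  # the flag added by this patch

    sources = create_source_list([str(tmp)], options)
    src = sources[0]
    # With the flag on, a single BuildSource for a.py is produced and its
    # merge_with attribute carries the BuildSource for a.pyi, so the stub's
    # annotations are merged into the source module during load_graph().
    print(src.path)              # .../a.py
    print(src.merge_with.path)   # .../a.pyi

With --merge-stub-into-src left at its default (off), the same call keeps the
previous behaviour and returns only the stub (a.pyi) for that module.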