Skip to content

Commit

Permalink
merge stubs into source first implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
gaborbernat committed Jun 7, 2018
1 parent 86df809 commit 848de98
Show file tree
Hide file tree
Showing 5 changed files with 513 additions and 22 deletions.
30 changes: 29 additions & 1 deletion mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from mypy.options import Options
from mypy.parse import parse
from mypy.stats import dump_type_stats
from mypy.stub_src_merge import MergeFiles
from mypy.types import Type
from mypy.version import __version__
from mypy.plugin import Plugin, DefaultPlugin, ChainedPlugin
Expand Down Expand Up @@ -97,11 +98,13 @@ def __init__(self, manager: 'BuildManager', graph: Graph) -> None:

class BuildSource:
def __init__(self, path: Optional[str], module: Optional[str],
text: Optional[str], base_dir: Optional[str] = None) -> None:
text: Optional[str], base_dir: Optional[str] = None,
merge_with: Optional['BuildSource'] = None) -> None:
self.path = path
self.module = module or '__main__'
self.text = text
self.base_dir = base_dir
self.merge_with = merge_with

def __repr__(self) -> str:
return '<BuildSource path=%r module=%r has_text=%s>' % (self.path,
Expand All @@ -122,6 +125,8 @@ def __init__(self, sources: List[BuildSource]) -> None:
self.source_text_present = True
elif source.path:
self.source_paths.add(source.path)
if source.merge_with and source.merge_with.path:
self.source_paths.add(source.merge_with.path)
else:
self.source_modules.add(source.module)

Expand Down Expand Up @@ -2236,6 +2241,22 @@ def generate_unused_ignore_notes(self) -> None:
self.verify_dependencies(suppressed_only=True)
self.manager.errors.generate_unused_ignore_notes(self.xpath)

def merge_with(self, state: 'State', errors: Errors) -> None:
    """Fold the module data of ``state`` into this state.

    The ancestor, child-module and dependency collections become the
    de-duplicated union of both states.  ``dep_line_map`` and ``priorities``
    receive the entries of ``state`` with their values negated; whenever a
    key exists on both sides, the entry already held by this state is kept.
    Finally the parse tree of ``state`` is merged into ``self.tree`` by
    ``MergeFiles``.

    Args:
        state: the state whose data is merged into this one.
        errors: error collector handed to ``MergeFiles``.
    """
    self.ancestors = list(set(self.ancestors or []) | set(state.ancestors or []))
    self.child_modules = list(set(self.child_modules) | set(state.child_modules))
    self.dependencies = list(set(self.dependencies) | set(state.dependencies))

    # Values coming from ``state`` are negated; entries of this state are
    # applied last so that they take precedence on shared keys.
    dep_line_map = {k: -v for k, v in state.dep_line_map.items()}
    dep_line_map.update(self.dep_line_map)
    self.dep_line_map = dep_line_map

    priorities = {k: -v for k, v in state.priorities.items()}
    priorities.update(self.priorities)
    self.priorities = priorities

    # NOTE: a trailing ``self.dep_line_map.update(<negated state map>)`` used
    # to follow here.  It overwrote the entries preserved above and made the
    # dep_line_map handling disagree with the priorities handling, so it has
    # been removed.
    MergeFiles(self.tree, state.tree, errors, self.xpath, self.id).run()


# Module import and diagnostic glue

Expand Down Expand Up @@ -2538,6 +2559,13 @@ def load_graph(sources: List[BuildSource], manager: BuildManager,
try:
st = State(id=bs.module, path=bs.path, source=bs.text, manager=manager,
root_source=True)
if bs.merge_with:
mw = bs.merge_with
src_st = State(id=mw.module, path=mw.path, source=mw.text,
manager=manager, root_source=True)
src_st.parse_file()
src_st.merge_with(st, manager.errors)
st = src_st
except ModuleNotFound:
continue
if st.id in graph:
Expand Down
58 changes: 37 additions & 21 deletions mypy/find_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ def create_source_list(files: Sequence[str], options: Options,
return targets


# Rank of every known Python source extension: its index in PY_EXTENSIONS.
PY_MAP = {k: i for i, k in enumerate(PY_EXTENSIONS)}


def keyfunc(name: str) -> Tuple[str, int]:
    """Determines sort order for directory listing.

    Entries are ordered by base name first and by extension rank second, so
    all files sharing a base name end up next to each other.  Names whose
    suffix is not in PY_EXTENSIONS (including suffix-less names) get rank -1
    and therefore sort before the recognised extensions of the same base.

    The desirable property is foo < foo.pyi < foo.py (this assumes that
    PY_EXTENSIONS lists '.pyi' before '.py' -- confirm at its definition).
    """
    base, suffix = os.path.splitext(name)
    return base, PY_MAP.get(suffix, -1)


class SourceFinder:
Expand All @@ -76,22 +76,38 @@ def expand_dir(self, arg: str, mod_prefix: str = '') -> List[BuildSource]:
sources.append(BuildSource(f, mod_prefix.rstrip('.'), None, base_dir))
names = self.fscache.listdir(arg)
names.sort(key=keyfunc)
for name in names:
path = os.path.join(arg, name)
if self.fscache.isdir(path):
sub_sources = self.expand_dir(path, mod_prefix + name + '.')
if sub_sources:
seen.add(name)
sources.extend(sub_sources)
else:
base, suffix = os.path.splitext(name)
if base == '__init__':
continue
if base not in seen and '.' not in base and suffix in PY_EXTENSIONS:
seen.add(base)
src = BuildSource(path, mod_prefix + base, None, base_dir)
sources.append(src)
return sources
name_iter = iter(names)
try:
name = next(name_iter, None)
while name is not None:
path = os.path.join(arg, name)
if self.fscache.isdir(path):
sub_sources = self.expand_dir(path, mod_prefix + name + '.')
if sub_sources:
seen.add(name)
sources.extend(sub_sources)
name = next(name_iter)
else:
base, suffix = os.path.splitext(name)
name = next(name_iter, None)
if base == '__init__':
continue
if base not in seen and '.' not in base and suffix in PY_EXTENSIONS:
seen.add(base)
if name is None:
next_base, next_suffix = None, None
else:
next_base, next_suffix = os.path.splitext(name)
merge_with = None
if next_base is not None and next_base == base:
merge_with = BuildSource(os.path.join(arg, name),
mod_prefix + next_base, None, base_dir)
src = BuildSource(path, mod_prefix + base, None, base_dir,
merge_with=merge_with)
sources.append(src)
return sources
except StopIteration:
return sources

def crawl_up(self, arg: str) -> Tuple[str, str]:
"""Given a .py[i] filename, return module and base directory
Expand Down
244 changes: 244 additions & 0 deletions mypy/stub_src_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import itertools
from typing import Any, Dict, List, Optional, Union

from mypy.errors import Errors
from mypy.nodes import (
AssignmentStmt,
CallExpr,
ClassDef,
Decorator,
EllipsisExpr,
FuncDef,
IfStmt,
ImportBase,
ImportedName,
MemberExpr,
MypyFile,
NameExpr,
Statement,
SymbolTable,
SymbolTableNode,
TypeInfo,
Var
)


class MergeFiles:
def __init__(
    self, src: MypyFile, stub: MypyFile, errors: Errors, src_xpath: str, src_id: str
) -> None:
    """Store both trees and direct error reporting at the source file.

    Args:
        src: parsed source module; it is the merge target.
        stub: parsed stub module providing the declarations to merge in.
        errors: error collector; ``set_file(src_xpath, src_id)`` is called
            on it here.
        src_xpath: path passed to ``errors.set_file``.
        src_id: module id passed to ``errors.set_file``.
    """
    errors.set_file(src_xpath, src_id)
    self.errors = errors
    self.stub = stub
    self.src = src
    # Line used by report() when the caller does not pass one explicitly.
    self._line = -1  # type: int

def report(
    self, msg: str, line: Optional[int] = None, column: Optional[int] = None
) -> None:
    """Forward a message to the error collector.

    Args:
        msg: text of the message.
        line: line to attach; when omitted, ``self._line`` is used.
        column: column to attach; when omitted, 0 is used.
    """
    effective_line = self._line if line is None else line
    effective_column = 0 if column is None else column
    self.errors.report(line=effective_line, message=msg, column=effective_column)

# Merge the stub module into the source module in place: the stub's imports are
# appended to the source's imports, the two symbol tables are merged through
# merge_names(), and src.defs is rebuilt so that the stub statements selected
# below come first, followed by every original source statement.
# NOTE(review): indentation is not visible in this view, so the nesting implied
# by the comments below is inferred from the statement order -- confirm it
# against the real file before relying on it.
def run(self):
src = self.src
stub = self.stub
src.imports.extend(stub.imports)

self.merge_names(src.names, stub.names)

# Index the source's top-level class definitions and plain-name assignments
# by name; an assignment is registered once for each NameExpr l-value.
src_names = {} # type: Dict[str, Union[ClassDef, AssignmentStmt]]
for d in src.defs:
if isinstance(d, ClassDef):
src_names[d.name] = d
elif isinstance(d, AssignmentStmt):
for l in d.lvalues:
if isinstance(l, NameExpr):
src_names[l.name] = d

# Collect the stub statements that are carried over into the source module.
defs = [] # type: List[Statement]
for i in stub.defs:
if isinstance(i, ImportBase):
# stub import statements are always carried over
defs.append(i)
elif isinstance(i, ClassDef):
# a stub class is carried over only when the source has nothing of that name
if i.name not in src_names:
defs.append(i)
elif isinstance(i, AssignmentStmt) and len(i.lvalues) == 1:
# only stub assignments with exactly one l-value are considered
entry = i.lvalues[0]
if isinstance(entry, NameExpr):
name = entry.name
if name not in src_names:
defs.append(i) # no source counterpart (could be a type alias), keep the stub statement
else:
# merge it into the source data: the source assignment takes over the
# stub's declared type
src_assignment = src_names[name]
if isinstance(src_assignment, AssignmentStmt):
src_assignment.type = i.type
src_assignment.unanalyzed_type = i.unanalyzed_type
# NOTE(review): this delete must sit inside the else branch above (the key
# has to exist); whether it is also inside the isinstance check cannot be
# told from this view.
del src_names[name]
# the original source statements follow the carried-over stub statements
defs.extend(src.defs)
src.defs = defs

def merge_names(self, src: "SymbolTable", stub: "SymbolTable") -> None:
    """Merge the entries of the stub symbol table into the source one.

    A name known only to the stub is copied into ``src`` as is.  A name
    known to both tables is handed to ``merge_symbols`` unless both tables
    refer to the very same object.
    """
    for name, stub_entry in stub.items():
        if name in src:
            src_entry = src.get(name)
            if src_entry is stub_entry:
                # Same object on both sides: nothing to reconcile.
                continue
            self.merge_symbols(src_entry, stub_entry)
            continue
        src[name] = stub_entry

def merge_symbols(self, src: Any, stub: Any) -> None:
    """Merge one stub-side object into its source-side counterpart.

    The line of ``src`` (or -1 when it has none) is remembered in
    ``self._line`` so that messages reported while merging point at the
    source.  When the two objects are of different types, a definition
    conflict is reported and nothing is merged.  Otherwise the merge is
    dispatched on the kind of object.

    Args:
        src: object taken from the source module (may be ``None``).
        stub: object taken from the stub module (may be ``None``).

    Raises:
        RuntimeError: both objects share a type that this method does not
            know how to merge.
    """
    self._line = getattr(src, "line", -1)
    if type(src) is not type(stub):
        # Either side may be None or lack a ``line`` attribute, so the stub
        # line is looked up defensively instead of failing with an
        # AttributeError while building the message.
        msg = "definition conflict of src {} and stub (line {}) {}".format(
            src.__class__.__qualname__,
            getattr(stub, "line", -1),
            stub.__class__.__qualname__,
        )
        self.report(msg)
        return
    if src is None or isinstance(src, ImportedName):
        pass
    elif isinstance(src, SymbolTableNode):
        src.type_override = stub.type_override
        self.merge_symbols(src.node, stub.node)
    elif isinstance(src, Var):
        self._merge_variables(src, stub)
    elif isinstance(src, FuncDef):
        self._merge_func_definition(src, stub)
    elif isinstance(src, TypeInfo):
        self._merge_type_info(src, stub)
    elif isinstance(src, AssignmentStmt):
        # The instance has to be tested here, not its type object: the
        # former ``isinstance(type(src), AssignmentStmt)`` was never true, so
        # assignment pairs fell through to the RuntimeError below.  Nothing
        # is copied for assignment statements at this point.
        pass
    elif isinstance(src, Decorator):
        self._merge_decorators(src, stub)
    else:
        raise RuntimeError("cannot merge {!r} with {!r}".format(src, stub))

@staticmethod
def _merge_variables(src, stub):
src.type = stub.type

def _merge_func_definition(self, src, stub):
if src.arg_names == stub.arg_names:
src.type = stub.type
src.arg_kinds = stub.arg_kinds
src.unanalyzed_type = stub.unanalyzed_type
# merge(left.body, right.body) - does this makes sense ? - we cannot annotate this
else:
msg = "arg conflict of src {} and stub (line {}) {}".format(
repr(src.arg_names), stub.line, repr(stub.arg_names)
)
self.report(msg)

# Merge a stub class (stub) into the matching source class (src).
# Class-level data (type_vars, metaclass_type, runtime_protocol and
# defn.type_vars) is copied from the stub, the class symbol tables are merged
# through merge_names(), and then the statements of both class bodies are paired
# by name.  Stub assignments and functions that find no partner in the source
# body are reported at the end.
# NOTE(review): indentation is not visible in this view, so the nesting implied
# by the comments below is inferred from the statement order -- confirm it
# against the real file.
def _merge_type_info(self, src, stub):
src.type_vars = stub.type_vars
src.metaclass_type = stub.metaclass_type
src.runtime_protocol = stub.runtime_protocol
self.merge_names(src.names, stub.names)
src.defn.type_vars = stub.defn.type_vars
stub_body = stub.defn.defs.body
src_body = src.defn.defs.body
# Stub class-body statements still waiting for a source counterpart, by name.
stub_func = {} # type: Dict[str, Union[Decorator, FuncDef]]
stub_assign = {} # type: Dict[str, AssignmentStmt]
for entry in stub_body:
if isinstance(entry, AssignmentStmt):
l_value = entry.lvalues
if len(l_value) != 1:
self.report(
"multiple l-values not supported in stub {}".format(entry.line)
)
# NOTE(review): no continue/else is visible after the report above, so the
# first l-value appears to be registered even for multi-target assignments.
stub_lvalue = l_value[0]
if isinstance(stub_lvalue, NameExpr):
stub_assign[stub_lvalue.name] = entry
elif isinstance(entry, Decorator):
# decorated functions are keyed by the name of the wrapped function
stub_func[entry.func.name()] = entry
elif isinstance(entry, FuncDef):
stub_func[entry.name()] = entry
# Walk the source class body and consume the matching stub entries.
for entry in src_body:
if isinstance(entry, FuncDef):
name = entry.name()
if name in stub_func:
self.merge_symbols(entry, stub_func[name])
del stub_func[name]
# Assignments found in __init__ (directly in its body, or one level down
# inside the blocks of an if statement) may take their type from a
# class-level stub assignment.
# NOTE(review): presumably this check is a sibling of the stub_func lookup
# above rather than nested inside it -- confirm.
if name == "__init__":
for init_part in entry.body.body:
if isinstance(init_part, AssignmentStmt):
self.class_level_stub_to_init_assign(init_part, stub_assign)
elif isinstance(init_part, IfStmt):
for b in init_part.body:
for part in b.body:
if isinstance(part, AssignmentStmt):
self.class_level_stub_to_init_assign(
part, stub_assign
)
elif isinstance(entry, AssignmentStmt):
# class-level source assignment with a single plain-name target
if len(entry.lvalues) == 1:
stub_lvalue = entry.lvalues[0]
if isinstance(stub_lvalue, NameExpr):
name = stub_lvalue.name
if name in stub_assign:
self.merge_symbols(entry, stub_assign[name])
del stub_assign[name]
elif isinstance(entry, Decorator):
name = entry.func.name()
if name in stub_func:
self.merge_symbols(entry, stub_func[name])
del stub_func[name]
# Whatever is left in the two maps exists in the stub only.
for k, v in stub_assign.items():
self.report("no source for assign {} @stub:{}".format(k, v.line))
for f_k, f_v in stub_func.items():
self.report("no source for func {} @{}".format(f_k, f_v.line))

# Compare the decorator lists of a source and a stub Decorator node pairwise
# and merge the wrapped functions through merge_symbols().
# NOTE(review): indentation is not visible in this view.  The trailing else is
# presumably the for-loop's else (merge the functions only when no comparison
# broke out of the loop), and the _decorator_expr_arg_checks call presumably
# runs for every CallExpr pair -- confirm both against the real file.
def _merge_decorators(self, src: "Decorator", stub: "Decorator"):
# zip_longest pads the shorter list with None, so lists of different length
# end in a pair of differing types and are reported as a conflict
for l, r in itertools.zip_longest(src.decorators, stub.decorators):
if type(l) != type(r):
self.report("decorator type conflict")
break
if isinstance(l, NameExpr):
# plain-name decorators: the names have to be equal
if not self._decorator_name_match(l, r):
break
elif isinstance(l, CallExpr):
# call decorators: compare callee names when both callees are plain names,
# then check the call arguments
if isinstance(l.callee, NameExpr) and isinstance(r.callee, NameExpr):
if not self._decorator_name_match(l.callee, r.callee):
break
self._decorator_expr_arg_checks(l, r)
else:
self.merge_symbols(src.func, stub.func)

def _decorator_expr_arg_checks(self, l, r):
for l_arg, r_arg in itertools.zip_longest(l.arg_names, r.arg_names):
if l_arg != r_arg:
msg = "decorator arg name conflict {} vs {}".format(l_arg, r_arg)
self.report(msg, l.line)
break
for n, t in zip(r.arg_names, r.args):
if not isinstance(t, EllipsisExpr):
msg = (
"stub decorator should not contain default value"
", {} has {}".format(n, type(t).__name__)
)
self.report(msg, l.line)
break

def _decorator_name_match(self, l: NameExpr, r: NameExpr) -> bool:
    """Tell whether two decorator name expressions carry the same name.

    A name conflict is reported when they differ.

    Returns:
        True when the names are equal, False otherwise.
    """
    same_name = l.name == r.name
    if not same_name:
        self.report("decorator name conflict {} vs {}".format(l.name, r.name))
    return same_name

@staticmethod
def class_level_stub_to_init_assign(init_part, stub_assign):
if len(init_part.lvalues) == 1:
l_value = init_part.lvalues[0]
if isinstance(l_value, MemberExpr):
name = l_value.name
if name in stub_assign:
init_part.type = stub_assign[name].type
init_part.unanalyzed_type = stub_assign[name].unanalyzed_type
del stub_assign[name]
return True
return False
Loading

0 comments on commit 848de98

Please sign in to comment.