-
-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
merge stubs into source first implementation
- Loading branch information
1 parent
86df809
commit 848de98
Showing
5 changed files
with
513 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
import itertools | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from mypy.errors import Errors | ||
from mypy.nodes import ( | ||
AssignmentStmt, | ||
CallExpr, | ||
ClassDef, | ||
Decorator, | ||
EllipsisExpr, | ||
FuncDef, | ||
IfStmt, | ||
ImportBase, | ||
ImportedName, | ||
MemberExpr, | ||
MypyFile, | ||
NameExpr, | ||
Statement, | ||
SymbolTable, | ||
SymbolTableNode, | ||
TypeInfo, | ||
Var | ||
) | ||
|
||
|
||
class MergeFiles: | ||
def __init__( | ||
self, src: MypyFile, stub: MypyFile, errors: Errors, src_xpath: str, src_id: str | ||
): | ||
self.src = src | ||
self.stub = stub | ||
self.errors = errors | ||
self.errors.set_file(src_xpath, src_id) | ||
self._line = -1 # type: int | ||
|
||
def report( | ||
self, msg: str, line: Optional[int] = None, column: Optional[int] = None | ||
): | ||
if line is None: | ||
line = self._line | ||
if column is None: | ||
column = 0 | ||
self.errors.report(line=line, message=msg, column=column) | ||
|
||
def run(self): | ||
src = self.src | ||
stub = self.stub | ||
src.imports.extend(stub.imports) | ||
|
||
self.merge_names(src.names, stub.names) | ||
|
||
src_names = {} # type: Dict[str, Union[ClassDef, AssignmentStmt]] | ||
for d in src.defs: | ||
if isinstance(d, ClassDef): | ||
src_names[d.name] = d | ||
elif isinstance(d, AssignmentStmt): | ||
for l in d.lvalues: | ||
if isinstance(l, NameExpr): | ||
src_names[l.name] = d | ||
|
||
defs = [] # type: List[Statement] | ||
for i in stub.defs: | ||
if isinstance(i, ImportBase): | ||
defs.append(i) | ||
elif isinstance(i, ClassDef): | ||
if i.name not in src_names: | ||
defs.append(i) | ||
elif isinstance(i, AssignmentStmt) and len(i.lvalues) == 1: | ||
entry = i.lvalues[0] | ||
if isinstance(entry, NameExpr): | ||
name = entry.name | ||
if name not in src_names: | ||
defs.append(i) # could be a type alias we add this | ||
else: | ||
# merge it into the source data | ||
src_assignment = src_names[name] | ||
if isinstance(src_assignment, AssignmentStmt): | ||
src_assignment.type = i.type | ||
src_assignment.unanalyzed_type = i.unanalyzed_type | ||
del src_names[name] | ||
defs.extend(src.defs) | ||
src.defs = defs | ||
|
||
def merge_names(self, src: "SymbolTable", stub: "SymbolTable") -> None: | ||
for key, value in stub.items(): | ||
if key not in src: | ||
src[key] = value | ||
else: | ||
main_value = src.get(key) | ||
if main_value is not value: | ||
self.merge_symbols(main_value, value) | ||
|
||
def merge_symbols(self, src: Any, stub: Any) -> None: | ||
t = type(src) | ||
self._line = getattr(src, "line", -1) | ||
if t != type(stub): | ||
msg = "definition conflict of src {} and stub (line {}) {}".format( | ||
src.__class__.__qualname__, stub.line, stub.__class__.__qualname__ | ||
) | ||
self.report(msg) | ||
return | ||
if src is None or isinstance(src, ImportedName): | ||
pass | ||
elif isinstance(src, SymbolTableNode): | ||
src.type_override = stub.type_override | ||
self.merge_symbols(src.node, stub.node) | ||
elif isinstance(src, Var): | ||
self._merge_variables(src, stub) | ||
elif isinstance(src, FuncDef): | ||
self._merge_func_definition(src, stub) | ||
elif isinstance(src, TypeInfo): | ||
self._merge_type_info(src, stub) | ||
elif isinstance(t, AssignmentStmt): | ||
pass | ||
elif isinstance(src, Decorator): | ||
self._merge_decorators(src, stub) | ||
else: | ||
raise RuntimeError("cannot merge {!r} with {!r}".format(src, stub)) | ||
|
||
@staticmethod | ||
def _merge_variables(src, stub): | ||
src.type = stub.type | ||
|
||
def _merge_func_definition(self, src, stub): | ||
if src.arg_names == stub.arg_names: | ||
src.type = stub.type | ||
src.arg_kinds = stub.arg_kinds | ||
src.unanalyzed_type = stub.unanalyzed_type | ||
# merge(left.body, right.body) - does this makes sense ? - we cannot annotate this | ||
else: | ||
msg = "arg conflict of src {} and stub (line {}) {}".format( | ||
repr(src.arg_names), stub.line, repr(stub.arg_names) | ||
) | ||
self.report(msg) | ||
|
||
def _merge_type_info(self, src, stub): | ||
src.type_vars = stub.type_vars | ||
src.metaclass_type = stub.metaclass_type | ||
src.runtime_protocol = stub.runtime_protocol | ||
self.merge_names(src.names, stub.names) | ||
src.defn.type_vars = stub.defn.type_vars | ||
stub_body = stub.defn.defs.body | ||
src_body = src.defn.defs.body | ||
stub_func = {} # type: Dict[str, Union[Decorator, FuncDef]] | ||
stub_assign = {} # type: Dict[str, AssignmentStmt] | ||
for entry in stub_body: | ||
if isinstance(entry, AssignmentStmt): | ||
l_value = entry.lvalues | ||
if len(l_value) != 1: | ||
self.report( | ||
"multiple l-values not supported in stub {}".format(entry.line) | ||
) | ||
stub_lvalue = l_value[0] | ||
if isinstance(stub_lvalue, NameExpr): | ||
stub_assign[stub_lvalue.name] = entry | ||
elif isinstance(entry, Decorator): | ||
stub_func[entry.func.name()] = entry | ||
elif isinstance(entry, FuncDef): | ||
stub_func[entry.name()] = entry | ||
for entry in src_body: | ||
if isinstance(entry, FuncDef): | ||
name = entry.name() | ||
if name in stub_func: | ||
self.merge_symbols(entry, stub_func[name]) | ||
del stub_func[name] | ||
if name == "__init__": | ||
for init_part in entry.body.body: | ||
if isinstance(init_part, AssignmentStmt): | ||
self.class_level_stub_to_init_assign(init_part, stub_assign) | ||
elif isinstance(init_part, IfStmt): | ||
for b in init_part.body: | ||
for part in b.body: | ||
if isinstance(part, AssignmentStmt): | ||
self.class_level_stub_to_init_assign( | ||
part, stub_assign | ||
) | ||
elif isinstance(entry, AssignmentStmt): | ||
if len(entry.lvalues) == 1: | ||
stub_lvalue = entry.lvalues[0] | ||
if isinstance(stub_lvalue, NameExpr): | ||
name = stub_lvalue.name | ||
if name in stub_assign: | ||
self.merge_symbols(entry, stub_assign[name]) | ||
del stub_assign[name] | ||
elif isinstance(entry, Decorator): | ||
name = entry.func.name() | ||
if name in stub_func: | ||
self.merge_symbols(entry, stub_func[name]) | ||
del stub_func[name] | ||
for k, v in stub_assign.items(): | ||
self.report("no source for assign {} @stub:{}".format(k, v.line)) | ||
for f_k, f_v in stub_func.items(): | ||
self.report("no source for func {} @{}".format(f_k, f_v.line)) | ||
|
||
def _merge_decorators(self, src: "Decorator", stub: "Decorator"): | ||
for l, r in itertools.zip_longest(src.decorators, stub.decorators): | ||
if type(l) != type(r): | ||
self.report("decorator type conflict") | ||
break | ||
if isinstance(l, NameExpr): | ||
if not self._decorator_name_match(l, r): | ||
break | ||
elif isinstance(l, CallExpr): | ||
if isinstance(l.callee, NameExpr) and isinstance(r.callee, NameExpr): | ||
if not self._decorator_name_match(l.callee, r.callee): | ||
break | ||
self._decorator_expr_arg_checks(l, r) | ||
else: | ||
self.merge_symbols(src.func, stub.func) | ||
|
||
def _decorator_expr_arg_checks(self, l, r): | ||
for l_arg, r_arg in itertools.zip_longest(l.arg_names, r.arg_names): | ||
if l_arg != r_arg: | ||
msg = "decorator arg name conflict {} vs {}".format(l_arg, r_arg) | ||
self.report(msg, l.line) | ||
break | ||
for n, t in zip(r.arg_names, r.args): | ||
if not isinstance(t, EllipsisExpr): | ||
msg = ( | ||
"stub decorator should not contain default value" | ||
", {} has {}".format(n, type(t).__name__) | ||
) | ||
self.report(msg, l.line) | ||
break | ||
|
||
def _decorator_name_match(self, l: NameExpr, r: NameExpr) -> bool: | ||
if l.name != r.name: | ||
msg = "decorator name conflict {} vs {}".format(l.name, r.name) | ||
self.report(msg) | ||
return False | ||
return True | ||
|
||
@staticmethod | ||
def class_level_stub_to_init_assign(init_part, stub_assign): | ||
if len(init_part.lvalues) == 1: | ||
l_value = init_part.lvalues[0] | ||
if isinstance(l_value, MemberExpr): | ||
name = l_value.name | ||
if name in stub_assign: | ||
init_part.type = stub_assign[name].type | ||
init_part.unanalyzed_type = stub_assign[name].unanalyzed_type | ||
del stub_assign[name] | ||
return True | ||
return False |
Oops, something went wrong.