Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gaborbernat committed Jun 4, 2018
1 parent 60b6927 commit 80c5a52
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 48 deletions.
12 changes: 5 additions & 7 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2209,12 +2209,8 @@ def merge_with(self, state: 'State', errors: Errors) -> None:
self.priorities = priorities

self.dep_line_map.update({k: -v for k, v in state.dep_line_map.items()})
e_rep = StubOrSrcErrors(errors, (self.xpath, self.id), (state.xpath, state.id))
if self.tree is not None:
self.tree.merge(state.tree, e_rep)
else:
import pdb
pdb.set_trace()
e_rep = StubOrSrcErrors(errors, (self.xpath, self.id))
self.tree.merge(state.tree, e_rep)


# Module import and diagnostic glue
Expand Down Expand Up @@ -2522,6 +2518,7 @@ def load_graph(sources: List[BuildSource], manager: BuildManager,
mw = bs.merge_with
src_st = State(id=mw.module, path=mw.path, source=mw.text,
manager=manager, root_source=True)
src_st.parse_file()
src_st.merge_with(st, manager.errors)
st = src_st
except ModuleNotFound:
Expand Down Expand Up @@ -2848,7 +2845,8 @@ def process_stale_scc(graph: Graph, scc: List[str], manager: BuildManager) -> No
graph[id].finish_passes()
if manager.options.cache_fine_grained or manager.options.fine_grained_incremental:
graph[id].compute_fine_grained_deps()
manager.flush_errors(manager.errors.file_messages(graph[id].xpath), False)
filename = graph[id].xpath
manager.flush_errors(manager.errors.file_messages(filename), False)
graph[id].write_cache()
graph[id].mark_as_rechecked()

Expand Down
50 changes: 26 additions & 24 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,18 +198,11 @@ def deserialize(cls, data: JsonDict) -> 'SymbolNode':

class StubOrSrcErrors:

def __init__(self, errors: 'Errors', src, stub):
def __init__(self, errors: 'Errors', src):
self.errors = errors
self.other = stub
self.errors.set_file(*src)
self.is_src = True

def report(self, line, msg, is_src=True):
if not (self.is_src and is_src):
x, i = self.errors.file, self.errors.target_module
self.errors.set_file(*self.other)
self.other = x, i
self.is_src = not self.is_src

def report(self, line, msg):
self.errors.report(line, 0, msg)


Expand Down Expand Up @@ -377,9 +370,8 @@ def merge_symbols(src: Any, stub: Any, errors: 'StubOrSrcErrors') -> None:
if isinstance(entry, AssignmentStmt):
l_value = entry.lvalues
if len(l_value) != 1:
errors.report(entry.line,
"multiple l-values not supported in stub",
is_src=False)
errors.report(0,
"multiple l-values not supported in stub {}".format(entry.line))
stub_assign[l_value[0].name] = entry
elif isinstance(entry, Decorator):
stub_func[entry.func.name()] = entry
Expand All @@ -394,15 +386,12 @@ def merge_symbols(src: Any, stub: Any, errors: 'StubOrSrcErrors') -> None:
if name == '__init__':
for init_part in entry.body.body:
if isinstance(init_part, AssignmentStmt):
if len(init_part.lvalues) == 1:
l_value = init_part.lvalues[0]
if isinstance(l_value, MemberExpr):
name = l_value.name
if name in stub_assign:
init_part.type = stub_assign[name].type
init_part.unanalyzed_type = stub_assign[
name].unanalyzed_type
del stub_assign[name]
class_level_stub_to_init_assign(init_part, stub_assign)
elif isinstance(init_part, IfStmt):
for b in init_part.body:
for part in b.body:
if isinstance(part, AssignmentStmt):
class_level_stub_to_init_assign(part, stub_assign)
elif isinstance(entry, AssignmentStmt):
name = entry.lvalues[0].name
if name in stub_assign:
Expand All @@ -414,9 +403,9 @@ def merge_symbols(src: Any, stub: Any, errors: 'StubOrSrcErrors') -> None:
merge_symbols(entry, stub_func[name], errors)
del stub_func[name]
for k, v in stub_assign.items():
errors.report(v.line, 'no source for assign {}'.format(k), is_src=False)
errors.report(src.line, 'no source for assign {} @stub:{}'.format(k, v.line))
for k, v in stub_func.items():
errors.report(v.line, 'no source for func {}'.format(k), is_src=False)
errors.report(src.line, 'no source for func {} @{}'.format(k, v.line))
elif t == AssignmentStmt:
pass
elif t == Decorator:
Expand All @@ -439,6 +428,19 @@ def merge_symbols(src: Any, stub: Any, errors: 'StubOrSrcErrors') -> None:
raise RuntimeError('cannot merge {!r} with {!r}'.format(src, stub))


def class_level_stub_to_init_assign(init_part, stub_assign):
if len(init_part.lvalues) == 1:
l_value = init_part.lvalues[0]
if isinstance(l_value, MemberExpr):
name = l_value.name
if name in stub_assign:
init_part.type = stub_assign[name].type
init_part.unanalyzed_type = stub_assign[name].unanalyzed_type
del stub_assign[name]
return True
return False


class ImportBase(Statement):
"""Base class for all import statements."""

Expand Down
49 changes: 32 additions & 17 deletions mypy/test/test_stub_src_merge.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest
import re
import textwrap
from typing import Callable, List
from typing import Callable

from mypy import build
from mypy.build import BuildSource, BuildResult
from mypy.build import BuildResult, BuildSource
from mypy.defaults import PYTHON3_VERSION
from mypy.options import Options

Expand All @@ -24,7 +24,7 @@ def checker_resource_fixture():


@pytest.fixture(name="checker")
def checker_fixture(checker_resource, tmpdir) -> Callable[[str, str, str], BuildResult]:
def checker_fixture(checker_resource, tmpdir, capsys) -> Callable[[str, str, str], BuildResult]:
line_nr, options = checker_resource

def _do(src, stub, result):
Expand All @@ -37,12 +37,13 @@ def _do(src, stub, result):
with open(path, 'wt') as f:
f.write(textwrap.dedent(content))

combined_ast, res = _build_ast(
combined_ast, combined_res = _build_ast(
BuildSource(stub_path, None, None,
merge_with=BuildSource(src_path, None, None)))
result_ast, _ = _build_ast(BuildSource(result_path, None, None))
assert combined_ast == result_ast
return res
result_ast, result_res = _build_ast(BuildSource(result_path, None, None))
capsys.readouterr()
assert result_ast == combined_ast
return combined_res

def _build_ast(source):
path = source.merge_with.path if source.merge_with is not None else source.path
Expand Down Expand Up @@ -80,25 +81,39 @@ def test_class_init_self_can_be_defined_at_class_level(checker):
result = checker("""
class A:
def __init__(self):
self.a = 1""", """
self.a = 's'""", """
class A:
a : str = ...
def __init__(self) -> None: ...""", """
class A:
def __init__(self) -> None:
self.a : str = 1""")
self.a : str = 's'""")
assert not result.errors


def test_cache_stub_modified(checker):
def test_class_init_self_can_be_defined_at_class_level_inside_if(checker):
result = checker("""
def fancy_add(a, b = None):
pass""", """
from typing import Union
def fancy_add(a : int, b : Union[None, int] = ...) -> int: ...""", """
from typing import Union
def fancy_add(a : int, b : Union[None, int] = None) -> int:
pass""")
class A:
def __init__(self):
if 0 > 10:
self.a = 0
else:
self.a = 1""", """
class A:
a : int = ...
def __init__(self) -> None: ...""", """
class A:
def __init__(self) -> None:
if 0 > 10:
self.a : int = 0
else:
self.a = 1""")
assert not result.errors


def test_variable_module_level(checker):
result = checker("""
a = 1""", """
a : int = ...""", """
a : int = 1""")
assert not result.errors

0 comments on commit 80c5a52

Please sign in to comment.