Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't type check most function bodies if ignoring errors #14150

Merged
merged 26 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,8 @@ def parse_file(
Raise CompileError if there is a parse error.
"""
t0 = time.time()
if ignore_errors:
self.errors.ignored_files.add(path)
tree = parse(source, path, id, self.errors, options=options)
tree._fullname = id
self.add_stats(
Expand Down
5 changes: 1 addition & 4 deletions mypy/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,7 @@ def split_directive(s: str) -> tuple[list[str], list[str]]:


def mypy_comments_to_config_map(line: str, template: Options) -> tuple[dict[str, str], list[str]]:
"""Rewrite the mypy comment syntax into ini file syntax.

Returns
"""
"""Rewrite the mypy comment syntax into ini file syntax."""
options = {}
entries, errors = split_directive(line)
for entry in entries:
Expand Down
179 changes: 166 additions & 13 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
)
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable
from mypy.sharedparse import argument_elide_name, special_function_elide_names
from mypy.traverser import TraverserVisitor
from mypy.types import (
AnyType,
CallableArgument,
Expand Down Expand Up @@ -260,6 +261,11 @@ def parse(
Return the parse tree. If errors is not provided, raise ParseError
on failure. Otherwise, use the errors object to report parse errors.
"""
ignore_errors = (options is not None and options.ignore_errors) or (
errors is not None and fnam in errors.ignored_files
)
# If errors are ignored, we can drop many function bodies to speed up type checking.
strip_function_bodies = ignore_errors and (options is None or not options.preserve_asts)
raise_on_error = False
if options is None:
options = Options()
Expand All @@ -281,7 +287,13 @@ def parse(
warnings.filterwarnings("ignore", category=DeprecationWarning)
ast = ast3_parse(source, fnam, "exec", feature_version=feature_version)

tree = ASTConverter(options=options, is_stub=is_stub_file, errors=errors).visit(ast)
tree = ASTConverter(
options=options,
is_stub=is_stub_file,
errors=errors,
ignore_errors=ignore_errors,
strip_function_bodies=strip_function_bodies,
).visit(ast)
tree.path = fnam
tree.is_stub = is_stub_file
except SyntaxError as e:
Expand Down Expand Up @@ -400,14 +412,24 @@ def is_no_type_check_decorator(expr: ast3.expr) -> bool:


class ASTConverter:
def __init__(self, options: Options, is_stub: bool, errors: Errors) -> None:
# 'C' for class, 'F' for function
self.class_and_function_stack: list[Literal["C", "F"]] = []
def __init__(
self,
options: Options,
is_stub: bool,
errors: Errors,
*,
ignore_errors: bool,
strip_function_bodies: bool,
) -> None:
# 'C' for class, 'D' for function signature, 'F' for function, 'L' for lambda
self.class_and_function_stack: list[Literal["C", "D", "F", "L"]] = []
self.imports: list[ImportBase] = []

self.options = options
self.is_stub = is_stub
self.errors = errors
self.ignore_errors = ignore_errors
self.strip_function_bodies = strip_function_bodies

self.type_ignores: dict[int, list[str]] = {}

Expand Down Expand Up @@ -475,7 +497,12 @@ def get_lineno(self, node: ast3.expr | ast3.stmt) -> int:
return node.lineno

def translate_stmt_list(
self, stmts: Sequence[ast3.stmt], ismodule: bool = False
self,
stmts: Sequence[ast3.stmt],
*,
ismodule: bool = False,
can_strip: bool = False,
is_coroutine: bool = False,
) -> list[Statement]:
# A "# type: ignore" comment before the first statement of a module
# ignores the whole module:
Expand Down Expand Up @@ -504,11 +531,41 @@ def translate_stmt_list(
mark_block_unreachable(block)
return [block]

stack = self.class_and_function_stack
if self.strip_function_bodies and len(stack) == 1 and stack[0] == "F":
return []

res: list[Statement] = []
for stmt in stmts:
node = self.visit(stmt)
res.append(node)

if (
self.strip_function_bodies
and can_strip
and stack[-2:] == ["C", "F"]
and not is_possible_trivial_body(res)
):
# We only strip method bodies if they don't assign to an attribute, as
# this may define an attribute which has an externally visible effect.
visitor = FindAttributeAssign()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to strip statements after last assignment to attribute? This could make a bit more perf gain.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be possible! I'd rather not do it in this PR to keep it simple (and the impact is probably pretty minor), but it's seems like a promising follow-up improvement to investigate.

for s in res:
s.accept(visitor)
if visitor.found:
break
else:
if is_coroutine:
# Yields inside an async function affect the return type and should not
# be stripped.
yield_visitor = FindYield()
for s in res:
s.accept(yield_visitor)
if yield_visitor.found:
break
else:
return []
else:
return []
return res

def translate_type_comment(
Expand Down Expand Up @@ -573,9 +630,20 @@ def as_block(self, stmts: list[ast3.stmt], lineno: int) -> Block | None:
b.set_line(lineno)
return b

def as_required_block(self, stmts: list[ast3.stmt], lineno: int) -> Block:
def as_required_block(
self,
stmts: list[ast3.stmt],
lineno: int,
*,
can_strip: bool = False,
is_coroutine: bool = False,
) -> Block:
assert stmts # must be non-empty
b = Block(self.fix_function_overloads(self.translate_stmt_list(stmts)))
b = Block(
self.fix_function_overloads(
self.translate_stmt_list(stmts, can_strip=can_strip, is_coroutine=is_coroutine)
)
)
# TODO: in most call sites line is wrong (includes first line of enclosing statement)
# TODO: also we need to set the column, and the end position here.
b.set_line(lineno)
Expand Down Expand Up @@ -831,9 +899,6 @@ def _is_stripped_if_stmt(self, stmt: Statement) -> bool:
# For elif, IfStmt are stored recursively in else_body
return self._is_stripped_if_stmt(stmt.else_body.body[0])

def in_method_scope(self) -> bool:
return self.class_and_function_stack[-2:] == ["C", "F"]

def translate_module_id(self, id: str) -> str:
"""Return the actual, internal module id for a source text id."""
if id == self.options.custom_typing_module:
Expand Down Expand Up @@ -868,7 +933,7 @@ def do_func_def(
self, n: ast3.FunctionDef | ast3.AsyncFunctionDef, is_coroutine: bool = False
) -> FuncDef | Decorator:
"""Helper shared between visit_FunctionDef and visit_AsyncFunctionDef."""
self.class_and_function_stack.append("F")
self.class_and_function_stack.append("D")
no_type_check = bool(
n.decorator_list and any(is_no_type_check_decorator(d) for d in n.decorator_list)
)
Expand Down Expand Up @@ -915,7 +980,8 @@ def do_func_def(
return_type = TypeConverter(self.errors, line=lineno).visit(func_type_ast.returns)

# add implicit self type
if self.in_method_scope() and len(arg_types) < len(args):
in_method_scope = self.class_and_function_stack[-2:] == ["C", "D"]
if in_method_scope and len(arg_types) < len(args):
arg_types.insert(0, AnyType(TypeOfAny.special_form))
except SyntaxError:
stripped_type = n.type_comment.split("#", 2)[0].strip()
Expand Down Expand Up @@ -965,7 +1031,10 @@ def do_func_def(
end_line = getattr(n, "end_lineno", None)
end_column = getattr(n, "end_col_offset", None)

func_def = FuncDef(n.name, args, self.as_required_block(n.body, lineno), func_type)
self.class_and_function_stack.pop()
self.class_and_function_stack.append("F")
body = self.as_required_block(n.body, lineno, can_strip=True, is_coroutine=is_coroutine)
func_def = FuncDef(n.name, args, body, func_type)
if isinstance(func_def.type, CallableType):
# semanal.py does some in-place modifications we want to avoid
func_def.unanalyzed_type = func_def.type.copy_modified()
Expand Down Expand Up @@ -1409,9 +1478,11 @@ def visit_Lambda(self, n: ast3.Lambda) -> LambdaExpr:
body.lineno = n.body.lineno
body.col_offset = n.body.col_offset

self.class_and_function_stack.append("L")
e = LambdaExpr(
self.transform_args(n.args, n.lineno), self.as_required_block([body], n.lineno)
)
self.class_and_function_stack.pop()
e.set_line(n.lineno, n.col_offset) # Overrides set_line -- can't use self.set_line
return e

Expand Down Expand Up @@ -2081,3 +2152,85 @@ def stringify_name(n: AST) -> str | None:
if sv is not None:
return f"{sv}.{n.attr}"
return None # Can't do it.


class FindAttributeAssign(TraverserVisitor):
"""Check if an AST contains attribute assignments (e.g. self.x = 0)."""

def __init__(self) -> None:
self.lvalue = False
self.found = False

def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you also need to check assignment expression? (i.e. walrus a.k.a :=)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't support assigning to an attribute.

self.lvalue = True
for lv in s.lvalues:
lv.accept(self)
self.lvalue = False

def visit_with_stmt(self, s: WithStmt) -> None:
self.lvalue = True
for lv in s.target:
if lv is not None:
lv.accept(self)
self.lvalue = False
s.body.accept(self)

def visit_for_stmt(self, s: ForStmt) -> None:
self.lvalue = True
s.index.accept(self)
self.lvalue = False
s.body.accept(self)
if s.else_body:
s.else_body.accept(self)

def visit_expression_stmt(self, s: ExpressionStmt) -> None:
# No need to look inside these
pass

def visit_call_expr(self, e: CallExpr) -> None:
# No need to look inside these
pass

def visit_index_expr(self, e: IndexExpr) -> None:
# No need to look inside these
pass

def visit_member_expr(self, e: MemberExpr) -> None:
if self.lvalue:
self.found = True


class FindYield(TraverserVisitor):
"""Check if an AST contains yields or yield froms."""

def __init__(self) -> None:
self.found = False

def visit_yield_expr(self, e: YieldExpr) -> None:
self.found = True

def visit_yield_from_expr(self, e: YieldFromExpr) -> None:
self.found = True


def is_possible_trivial_body(s: list[Statement]) -> bool:
"""Could the statements form a "trivial" function body, such as 'pass'?

This mimics mypy.semanal.is_trivial_body, but this runs before
semantic analysis so some checks must be conservative.
"""
l = len(s)
if l == 0:
return False
i = 0
if isinstance(s[0], ExpressionStmt) and isinstance(s[0].expr, StrExpr):
# Skip docstring
i += 1
if i == l:
return True
if l > i + 1:
return False
stmt = s[i]
return isinstance(stmt, (PassStmt, RaiseStmt)) or (
isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr)
)
11 changes: 9 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6684,7 +6684,7 @@ def is_trivial_body(block: Block) -> bool:
"..." (ellipsis), or "raise NotImplementedError()". A trivial body may also
start with a statement containing just a string (e.g. a docstring).

Note: functions that raise other kinds of exceptions do not count as
Note: Functions that raise other kinds of exceptions do not count as
"trivial". We use this function to help us determine when it's ok to
relax certain checks on body, but functions that raise arbitrary exceptions
are more likely to do non-trivial work. For example:
Expand All @@ -6694,11 +6694,18 @@ def halt(self, reason: str = ...) -> NoReturn:

A function that raises just NotImplementedError is much less likely to be
this complex.

Note: If you update this, you may also need to update
mypy.fastparse.is_possible_trivial_body!
"""
body = block.body
if not body:
# Functions have empty bodies only if the body is stripped or the function is
# generated or deserialized. In these cases the body is unknown.
return False

# Skip a docstring
if body and isinstance(body[0], ExpressionStmt) and isinstance(body[0].expr, StrExpr):
if isinstance(body[0], ExpressionStmt) and isinstance(body[0].expr, StrExpr):
body = block.body[1:]

if len(body) == 0:
Expand Down
1 change: 1 addition & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,7 @@ def mypy_options(stubgen_options: Options) -> MypyOptions:
options.python_version = stubgen_options.pyversion
options.show_traceback = True
options.transform_source = remove_misplaced_type_comments
options.preserve_asts = True
return options


Expand Down
15 changes: 10 additions & 5 deletions mypy/test/testparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from pytest import skip

from mypy import defaults
from mypy.config_parser import parse_mypy_comments
from mypy.errors import CompileError
from mypy.options import Options
from mypy.parse import parse
from mypy.test.data import DataDrivenTestCase, DataSuite
from mypy.test.helpers import assert_string_arrays_equal, find_test_files, parse_options
from mypy.util import get_mypy_comments


class ParserSuite(DataSuite):
Expand Down Expand Up @@ -40,13 +42,16 @@ def test_parser(testcase: DataDrivenTestCase) -> None:
else:
options.python_version = defaults.PYTHON3_VERSION

source = "\n".join(testcase.input)

# Apply mypy: comments to options.
comments = get_mypy_comments(source)
changes, _ = parse_mypy_comments(comments, options)
options = options.apply_changes(changes)

try:
n = parse(
bytes("\n".join(testcase.input), "ascii"),
fnam="main",
module="__main__",
errors=None,
options=options,
bytes(source, "ascii"), fnam="main", module="__main__", errors=None, options=options
)
a = n.str_with_options(options).split("\n")
except CompileError as e:
Expand Down
Loading