diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b7c9e14..b39d6c9e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ Bugfixes: [PEP 570 syntax](https://peps.python.org/pep-0570/) and the first positional-or-keyword parameter following the positional-only parameters used a custom TypeVar (see #455). +* Y046: Fix false negative where an unused protocol would not be detected if + the protocol was generic. ## 23.11.0 diff --git a/pyi.py b/pyi.py index 25c41553..6bd7d834 100644 --- a/pyi.py +++ b/pyi.py @@ -7,11 +7,11 @@ import re import sys from collections import Counter, defaultdict -from collections.abc import Container, Iterable, Iterator, Sequence +from collections.abc import Container, Iterable, Iterator, Sequence, Set as AbstractSet from contextlib import contextmanager from copy import deepcopy from dataclasses import dataclass -from functools import partial +from functools import cached_property, partial from itertools import chain, groupby, zip_longest from keyword import iskeyword from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Union @@ -438,38 +438,6 @@ def _analyse_exit_method_arg(node: ast.BinOp) -> ExitArgAnalysis: return ExitArgAnalysis(is_union_with_None=False, non_None_part=None) -def _is_decorated_with_final( - node: ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef, -) -> bool: - return any(_is_final(decorator) for decorator in node.decorator_list) - - -def _get_collections_abc_obj_id(node: ast.expr | None) -> str | None: - """ - If the node represents a subscripted object from collections.abc or typing, - return the name of the object. - Else, return None. - - >>> _get_collections_abc_obj_id(_ast_node_for('AsyncIterator[str]')) - 'AsyncIterator' - >>> _get_collections_abc_obj_id(_ast_node_for('typing.AsyncIterator[str]')) - 'AsyncIterator' - >>> node = _ast_node_for('typing_extensions.AsyncIterator[str]') - >>> _get_collections_abc_obj_id(node) - 'AsyncIterator' - >>> _get_collections_abc_obj_id(_ast_node_for('collections.abc.AsyncIterator[str]')) - 'AsyncIterator' - >>> node = _ast_node_for('collections.OrderedDict[str, int]') - >>> _get_collections_abc_obj_id(node) is None - True - """ - if not isinstance(node, ast.Subscript): - return None - return _get_name_of_class_if_from_modules( - node.value, modules=_TYPING_OR_COLLECTIONS_ABC - ) - - _INPLACE_BINOP_METHODS = frozenset( { "__iadd__", @@ -490,12 +458,12 @@ def _get_collections_abc_obj_id(node: ast.expr | None) -> str | None: def _has_bad_hardcoded_returns( - method: ast.FunctionDef | ast.AsyncFunctionDef, *, classdef: ast.ClassDef + method: ast.FunctionDef | ast.AsyncFunctionDef, *, class_ctx: EnclosingClassContext ) -> bool: """Return `True` if `function` should be rewritten with `typing_extensions.Self`.""" # PEP 673 forbids the use of `typing(_extensions).Self` in metaclasses. # Do our best to avoid false positives here: - if _is_metaclass(classdef): + if class_ctx.is_metaclass: return False # Much too complex for our purposes to worry @@ -514,28 +482,34 @@ def _has_bad_hardcoded_returns( if isinstance(method, ast.AsyncFunctionDef): return ( method_name == "__aenter__" - and _is_name(returns, classdef.name) - and not _is_decorated_with_final(classdef) + and _is_name(returns, class_ctx.cls_name) + and not class_ctx.is_decorated_with_final ) if method_name in _INPLACE_BINOP_METHODS: return returns is not None and not _is_Self(returns) - if _is_name(returns, classdef.name): - return method_name in {"__enter__", "__new__"} and not _is_decorated_with_final( - classdef + if _is_name(returns, class_ctx.cls_name): + return ( + method_name in {"__enter__", "__new__"} + and not class_ctx.is_decorated_with_final ) - return_obj_name = _get_collections_abc_obj_id(returns) - bases = {_get_collections_abc_obj_id(base_node) for base_node in classdef.bases} - - if method_name == "__iter__": - return return_obj_name in {"Iterable", "Iterator"} and "Iterator" in bases - elif method_name == "__aiter__": - return ( - return_obj_name in {"AsyncIterable", "AsyncIterator"} - and "AsyncIterator" in bases + if isinstance(returns, ast.Subscript): + return_obj_name = _get_name_of_class_if_from_modules( + returns.value, modules=_TYPING_OR_COLLECTIONS_ABC ) + if method_name == "__iter__": + bad_returns = {"Iterable", "Iterator"} + return return_obj_name in bad_returns and class_ctx.contains_in_bases( + "Iterator", from_=_TYPING_OR_COLLECTIONS_ABC + ) + elif method_name == "__aiter__": + bad_returns = {"AsyncIterable", "AsyncIterator"} + return return_obj_name in bad_returns and class_ctx.contains_in_bases( + "AsyncIterator", from_=_TYPING_OR_COLLECTIONS_ABC + ) + return False @@ -708,6 +682,102 @@ def _analyse_typing_Literal(node: ast.Subscript) -> TypingLiteralAnalysis: ) +_KNOWN_ENUM_BASES = frozenset( + {"Enum", "Flag", "IntEnum", "IntFlag", "StrEnum", "ReprEnum"} +) + + +_COMMON_METACLASSES = { + "type": "builtins", + "ABCMeta": "abc", + "EnumMeta": "enum", + "EnumType": "enum", +} + + +@dataclass(frozen=True) +class EnclosingClassContext: + node: ast.ClassDef + cls_name: str + bases_map: defaultdict[str, set[str | None]] + + def contains_in_bases(self, obj: str, *, from_: AbstractSet[str]) -> bool: + if obj not in self.bases_map: + return False + if None in self.bases_map[obj]: + return True + return bool(self.bases_map[obj] & from_) + + @cached_property + def is_protocol_class(self) -> bool: + return self.contains_in_bases("Protocol", from_=_TYPING_MODULES) + + @cached_property + def is_typeddict_class(self) -> bool: + return self.contains_in_bases( + "TypedDict", from_=(_TYPING_MODULES | {"mypy_extensions"}) + ) + + @cached_property + def is_enum_class(self) -> bool: + return any( + self.contains_in_bases(enum_cls, from_={"enum"}) + for enum_cls in _KNOWN_ENUM_BASES + ) + + @cached_property + def is_metaclass(self) -> bool: + return any( + self.contains_in_bases(metacls, from_={module}) + for metacls, module in _COMMON_METACLASSES.items() + ) + + @cached_property + def is_decorated_with_final(self) -> bool: + return any(_is_final(decorator) for decorator in self.node.decorator_list) + + +class ClassBase(NamedTuple): + module: str | None + obj: str + + +def _analyze_classdef(node: ast.ClassDef) -> EnclosingClassContext: + bases_map: defaultdict[str, set[str | None]] = defaultdict(set) + + def _unravel(node: ast.expr) -> str | None: + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + value = _unravel(node.value) + if value is None: + return None + return f"{value}.{node.attr}" + return None + + def _analyze_base_node( + base_node: ast.expr, top_level: bool = True + ) -> ClassBase | None: + if isinstance(base_node, ast.Name): + return ClassBase(None, base_node.id) + if isinstance(base_node, ast.Attribute): + value = _unravel(base_node.value) + if value is None: + return None + return ClassBase(value, base_node.attr) + if isinstance(base_node, ast.Subscript) and top_level: + return _analyze_base_node(base_node.value, top_level=False) + return None + + for base_node in node.bases: + base = _analyze_base_node(base_node) + if base is None: + continue + bases_map[base.obj].add(base.module) + + return EnclosingClassContext(node=node, cls_name=node.name, bases_map=bases_map) + + _ALLOWED_MATH_ATTRIBUTES_IN_DEFAULTS = frozenset( {"math.inf", "math.nan", "math.e", "math.pi", "math.tau"} ) @@ -854,49 +924,6 @@ def _is_valid_default_value_without_annotation(node: ast.expr) -> bool: ) -_KNOWN_ENUM_BASES = frozenset( - {"Enum", "Flag", "IntEnum", "IntFlag", "StrEnum", "ReprEnum"} -) - - -def _is_enum_base(node: ast.expr) -> bool: - if isinstance(node, ast.Name): - return node.id in _KNOWN_ENUM_BASES - return ( - isinstance(node, ast.Attribute) - and isinstance(node.value, ast.Name) - and node.value.id == "enum" - and node.attr in _KNOWN_ENUM_BASES - ) - - -def _is_enum_class(node: ast.ClassDef) -> bool: - return any(_is_enum_base(base) for base in node.bases) - - -_COMMON_METACLASSES = { - "type": "builtins", - "ABCMeta": "abc", - "EnumMeta": "enum", - "EnumType": "enum", -} - - -def _is_metaclass_base(node: ast.expr) -> bool: - if isinstance(node, ast.Name): - return node.id in _COMMON_METACLASSES - return ( - isinstance(node, ast.Attribute) - and node.attr in _COMMON_METACLASSES - and _is_name(node.value, _COMMON_METACLASSES[node.attr]) - ) - - -def _is_metaclass(node: ast.ClassDef) -> bool: - """Best-effort attempt to determine if a class is a metaclass or not.""" - return any(_is_metaclass_base(base) for base in node.bases) - - def _check_import_or_attribute( node: ast.Attribute | ast.ImportFrom, module_name: str, object_name: str ) -> str | None: @@ -1000,11 +1027,10 @@ class PyiVisitor(ast.NodeVisitor): string_literals_allowed: NestingCounter long_strings_allowed: NestingCounter in_function: NestingCounter - in_class: NestingCounter visiting_arg: NestingCounter # This is only relevant for visiting classes - current_class_node: ast.ClassDef | None = None + enclosing_class_ctx: EnclosingClassContext | None = None def __init__(self, filename: str) -> None: self.filename = filename @@ -1018,12 +1044,18 @@ def __init__(self, filename: str) -> None: self.string_literals_allowed = NestingCounter() self.long_strings_allowed = NestingCounter() self.in_function = NestingCounter() - self.in_class = NestingCounter() self.visiting_arg = NestingCounter() def __repr__(self) -> str: return f"{self.__class__.__name__}(filename={self.filename!r})" + @property + def visiting_enum_class(self) -> bool: + return ( + self.enclosing_class_ctx is not None + and self.enclosing_class_ctx.is_enum_class + ) + def visit_Attribute(self, node: ast.Attribute) -> None: self.generic_visit(node) if error_msg := _check_import_or_attribute( @@ -1084,11 +1116,7 @@ def _check_default_value_without_type_annotation( return if _is_valid_default_value_with_annotation(assignment): # Annoying special-casing to exclude enums from Y052 - if self.in_class.active: - assert self.current_class_node is not None - if not _is_enum_class(self.current_class_node): - self.error(node, Y052.format(variable=target_name)) - else: + if not self.visiting_enum_class: self.error(node, Y052.format(variable=target_name)) else: self.error(node, Y015) @@ -1109,7 +1137,7 @@ def visit_Assign(self, node: ast.Assign) -> None: self.error(node, Y017) target = target_name = None is_special_assignment = _is_assignment_which_must_have_a_value( - target_name, in_class=self.in_class.active + target_name, in_class=self.enclosing_class_ctx is not None ) assignment = node.value if isinstance(assignment, ast.Call): @@ -1269,7 +1297,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None: is_special_assignment = isinstance( node_target, ast.Name ) and _is_assignment_which_must_have_a_value( - node_target.id, in_class=self.in_class.active + node_target.id, in_class=self.enclosing_class_ctx is not None ) is_typealias = _is_TypeAlias(node_annotation) and isinstance( @@ -1659,22 +1687,18 @@ def _check_class_bases(self, bases: list[ast.expr]) -> None: self.error(Generic_basenode, msg) def visit_ClassDef(self, node: ast.ClassDef) -> None: - if node.name.startswith("_") and not self.in_class.active: - for base in node.bases: - if _is_Protocol(base): - self.protocol_defs[node.name].append(node) - break - if _is_TypedDict(base): - self.class_based_typeddicts[node.name].append(node) - break - - old_class_node = self.current_class_node - self.current_class_node = node - with self.in_class.enabled(): - self.generic_visit(node) - self.current_class_node = old_class_node + old_context = self.enclosing_class_ctx + self.enclosing_class_ctx = _analyze_classdef(node) + + if node.name.startswith("_"): + if self.enclosing_class_ctx.is_protocol_class: + self.protocol_defs[node.name].append(node) + elif self.enclosing_class_ctx.is_typeddict_class: + self.class_based_typeddicts[node.name].append(node) + self.generic_visit(node) self._check_class_bases(node.bases) + self.enclosing_class_ctx = old_context # empty class body should contain "..." not "pass" if len(node.body) == 1: @@ -1889,14 +1913,14 @@ def _check_aiter_returns( example_returns = f"AsyncIterator[{unparse(returns.slice.elts[0])}]" self._Y058_error(node, non_kw_only_args, example_returns) - def _visit_synchronous_method(self, node: ast.FunctionDef) -> None: + def _visit_synchronous_method( + self, node: ast.FunctionDef, class_ctx: EnclosingClassContext + ) -> None: method_name = node.name all_args = node.args - classdef = self.current_class_node - assert classdef is not None - if _has_bad_hardcoded_returns(node, classdef=classdef): - return self._Y034_error(node=node, cls_name=classdef.name) + if _has_bad_hardcoded_returns(node, class_ctx=class_ctx): + return self._Y034_error(node=node, cls_name=class_ctx.cls_name) returns = node.returns @@ -1930,17 +1954,15 @@ def _visit_synchronous_method(self, node: ast.FunctionDef) -> None: self.error(node, Y032.format(method_name=method_name)) def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - if self.in_class.active: - self._visit_synchronous_method(node) + if self.enclosing_class_ctx is not None: + self._visit_synchronous_method(node, self.enclosing_class_ctx) self._visit_function(node) def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: - if self.in_class.active: - classdef = self.current_class_node - assert classdef is not None + if self.enclosing_class_ctx is not None: method_name = node.name - if _has_bad_hardcoded_returns(node, classdef=classdef): - self._Y034_error(node=node, cls_name=classdef.name) + if _has_bad_hardcoded_returns(node, class_ctx=self.enclosing_class_ctx): + self._Y034_error(node=node, cls_name=self.enclosing_class_ctx.cls_name) elif method_name == "__aexit__": self._check_exit_method(node=node, method_name=method_name) self._visit_function(node) @@ -2075,7 +2097,7 @@ def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: ): self.error(statement, Y010) - if self.in_class.active: + if self.enclosing_class_ctx is not None: self.check_self_typevars(node) def visit_arg(self, node: ast.arg) -> None: diff --git a/tests/unused_things.pyi b/tests/unused_things.pyi index ddb05337..e602ec9b 100644 --- a/tests/unused_things.pyi +++ b/tests/unused_things.pyi @@ -6,9 +6,14 @@ import mypy_extensions import typing_extensions from typing_extensions import Literal, TypeAlias +_T = TypeVar("_T") + class _Foo(Protocol): # Y046 Protocol "_Foo" is not used bar: int +class _GenericFoo(Protocol[_T]): # Y046 Protocol "_GenericFoo" is not used + bar: _T + class _Bar(typing.Protocol): # Y046 Protocol "_Bar" is not used bar: int @@ -21,7 +26,11 @@ class UnusedButPublicProtocol(Protocol): class _UsedPrivateProtocol(Protocol): bar: int +class _UsedPrivateGenericProtocol(Protocol[_T]): + bar: _T + def uses__UsedPrivateProtocol(arg: _UsedPrivateProtocol) -> None: ... +def uses__UsedPrivateGenericProtocol(arg: _UsedPrivateGenericProtocol) -> None: ... _UnusedPrivateAlias: TypeAlias = str | int # Y047 Type alias "_UnusedPrivateAlias" is not used PublicAlias: TypeAlias = str | int @@ -83,3 +92,8 @@ else: class _ConditionallyDefinedUnusedProtocol(Protocol): foo: int bar: str + +# Tests to make sure we handle edge-cases in the AST correctly: +class Strange: ... +class _NotAProtocol(Strange[Strange][Strange].Protocol): ... +class _AlsoNotAProtocol(Protocol[Protocol][Protocol][Protocol]): ...