Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more precise inference for enum attributes #6867

Merged
merged 7 commits into from
May 30, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 68 additions & 11 deletions mypy/plugins/enums.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,49 @@
"""
This file contains a variety of plugins for refining how mypy infers types of
expressions involving Enums.

Currently, this file focuses on providing better inference for expressions like
'SomeEnum.FOO.name' and 'SomeEnum.FOO.value'. Note that the type of both expressions
will vary depending on exactly which instance of SomeEnum we're looking at.

Note that this file does *not* contain all special-cased logic related to enums:
we actually bake some of it directly in to the semantic analysis layer (see
semanal_enum.py).
"""
from typing import Optional
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
from typing_extensions import Final
import mypy.plugin # To avoid circular imports.
from mypy.types import Type, Instance, LiteralType
from mypy.nodes import Var, MDEF

# Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use
# enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes.
ENUM_PREFIXES = ['enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag']
ENUM_NAME_ACCESS = (
['{}.name'.format(prefix) for prefix in ENUM_PREFIXES]
+ ['{}._name_'.format(prefix) for prefix in ENUM_PREFIXES]
ENUM_PREFIXES: Final = {'enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'}
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
ENUM_NAME_ACCESS: Final = (
{'{}.name'.format(prefix) for prefix in ENUM_PREFIXES}
| {'{}._name_'.format(prefix) for prefix in ENUM_PREFIXES}
)
ENUM_VALUE_ACCESS = (
['{}.value'.format(prefix) for prefix in ENUM_PREFIXES]
+ ['{}._value_'.format(prefix) for prefix in ENUM_PREFIXES]
ENUM_VALUE_ACCESS: Final = (
{'{}.value'.format(prefix) for prefix in ENUM_PREFIXES}
| {'{}._value_'.format(prefix) for prefix in ENUM_PREFIXES}
)


def enum_name_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
enum_field_name = extract_underlying_field_name(ctx.type)
"""This plugin refines the 'name' attribute in enums to act as if
they were declared to be final.

For example, the expression 'MyEnum.FOO.name' normally is inferred
to be of type 'str'.

This plugin will instead make the inferred type be a 'str' where the
last known value is 'Literal["FOO"]'. This means it would be legal to
use 'MyEnum.FOO.name' in contexts that expect a Literal type, just like
any other Final variable or attribute.

This plugin assumes that the provided context is an attribute access
matching one of the strings found in 'ENUM_NAME_ACCESS'.
"""
enum_field_name = _extract_underlying_field_name(ctx.type)
if enum_field_name is None:
return ctx.default_attr_type
else:
Expand All @@ -27,7 +53,27 @@ def enum_name_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:


def enum_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
enum_field_name = extract_underlying_field_name(ctx.type)
"""This plugin refines the 'value' attribute in enums to refer to
the original underlying value. For example, suppose we have the
following:

class SomeEnum:
FOO = A()
BAR = B()

By default, mypy will infer that 'SomeEnum.FOO.value' and
'SomeEnum.BAR.value' both are of type 'Any'. This plugin refines
this inference so that mypy understands the expressions are
actually of types 'A' and 'B' respectively. This better reflects
the actual runtime behavior.

This plugin works simply by looking up the original value assigned
to the enum. For example,
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved

This plugin assumes that the provided context is an attribute access
matching one of the strings found in 'ENUM_VALUE_ACCESS'.
"""
enum_field_name = _extract_underlying_field_name(ctx.type)
if enum_field_name is None:
return ctx.default_attr_type

Expand All @@ -51,7 +97,18 @@ def enum_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
return underlying_type


def extract_underlying_field_name(typ: Type) -> Optional[str]:
def _extract_underlying_field_name(typ: Type) -> Optional[str]:
"""If the given type corresponds to some Enum instance, returns the
original name of that enum. For example, if we receive in the type
corresponding to 'SomeEnum.FOO', we return the string "SomeEnum.Foo".

This helper takes advantage of the fact that Enum instances are valid
to use inside Literal[...] types. An expression like 'SomeEnum.FOO' is
actually represented by an Instance type with a Literal enum fallback.

We can examine this Literal fallback to retrieve the string.
"""

if not isinstance(typ, Instance):
return None

Expand Down
5 changes: 4 additions & 1 deletion test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ class Test(Enum):
b = auto()

reveal_type(Test.a) # E: Revealed type is '__main__.Test'
[builtins fixtures/primitives.pyi]

[case testEnumAttributeAccessMatrix]
from enum import Enum, IntEnum, IntFlag, Flag, EnumMeta, auto
Expand Down Expand Up @@ -553,9 +554,10 @@ reveal_type(D3.x.value) # E: Revealed type is 'builtins.int'
reveal_type(D3.x._value_) # E: Revealed type is 'builtins.int'

# TODO: Generalize our enum functional API logic to work with subclasses of Enum
# See https://github.com/python/mypy/issues/6037

class Parent(Enum): pass
#E1 = Parent('E1', 'x')
# E1 = Parent('E1', 'x') # See above TODO
class E2(Parent):
x = auto()
class E3(Parent):
Expand Down Expand Up @@ -586,6 +588,7 @@ F3.x.name # E: "F3" has no attribute "name"
F3.x._name_ # E: "F3" has no attribute "_name_"
F3.x.value # E: "F3" has no attribute "value"
F3.x._value_ # E: "F3" has no attribute "_value_"
[builtins fixtures/primitives.pyi]

[case testEnumAttributeChangeIncremental]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that testing deserialization of the related types would also be an interesting test case. I wonder if one exists?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'm not sure if we have one. Do you know which file I should add the test to? (I don't remember where we keep the deserialization tests.)

from a import SomeEnum
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -2427,7 +2427,7 @@ from typing import Protocol
class P(Protocol): ...
class C(P): ...

reveal_type(C.register(int)) # E: Revealed type is 'def (x: builtins.object =, base: builtins.int =) -> builtins.int'
reveal_type(C.register(int)) # E: Revealed type is 'def () -> builtins.int'
[typing fixtures/typing-full.pyi]
[out]

Expand Down
2 changes: 2 additions & 0 deletions test-data/unit/fixtures/primitives.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class type:
def __init__(self, x) -> None: pass

class int:
# Note: this is a simplification of the actual signature
def __init__(self, x: object = ..., base: int = ...) -> None: pass
def __add__(self, i: int) -> int: pass
class float:
def __float__(self) -> float: pass
Expand Down
2 changes: 0 additions & 2 deletions test-data/unit/lib-stub/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ class type:

# These are provided here for convenience.
class int:
# Note: this is a simplification of the actual signature
def __init__(self, x: object = ..., base: int = ...) -> None: pass
def __add__(self, other: int) -> int: pass
def __rmul__(self, other: int) -> int: pass
class float: pass
Expand Down