Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement type-aware get for TypedDict #2620

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
PartialType, DeletedType, UnboundType, UninhabitedType, TypeType,
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
get_typ_args, set_typ_args,
)
TypedDictGetFunction)
from mypy.nodes import (
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr,
Expand Down Expand Up @@ -347,6 +347,24 @@ def check_call(self, callee: Type, args: List[Expression],
callee.type_object().name(), type.abstract_attributes,
context)

if isinstance(callee, TypedDictGetFunction):
if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)):
return_type = self.get_typeddict_index_type(callee.typed_dict, args[0])
arg_types = callee.arg_types
if len(args) == 1:
return_type = UnionType.make_union([
return_type, NoneTyp()])
elif isinstance(return_type, TypedDictType) and len(callee.arg_types) == 2:
# Explicitly set the type of the default parameter to
# Union[typing.Mapping, <return type>] in cases where the return value
# is a typed dict. This special case allows for chaining of `get` methods
# when accessing elements deep within nested dictionaries in a safe and
# concise way without having to set up exception handlers.
arg_types = [callee.arg_types[0],
UnionType.make_union([return_type,
self.named_type('typing.Mapping')])]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would better to use the type such as Mapping[str, Any]. Not having type arguments for Mapping may cause trouble in some cases (e.g. index errors). This still wouldn't be quite safe, but let's deal with that as a separate issue.

callee = callee.copy_modified(ret_type=return_type, arg_types=arg_types)

formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names,
callee.arg_kinds, callee.arg_names,
Expand Down Expand Up @@ -1484,11 +1502,13 @@ def _get_value(self, index: Expression) -> Optional[int]:
return None

def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type:
return self.get_typeddict_index_type(td_type, index)

def get_typeddict_index_type(self, td_type: TypedDictType, index: Expression) -> Type:
if not isinstance(index, (StrExpr, UnicodeExpr)):
self.msg.typeddict_item_name_must_be_string_literal(td_type, index)
return AnyType()
item_name = index.value

item_type = td_type.items.get(item_name)
if item_type is None:
self.msg.typeddict_item_name_not_found(td_type, item_name, index)
Expand Down
13 changes: 8 additions & 5 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from mypy.types import (
Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef,
Overloaded, TypeVarType, UnionType, PartialType,
DeletedType, NoneTyp, TypeType, function_type
)
DeletedType, NoneTyp, TypeType, function_type,
TypedDictGetFunction)
from mypy.nodes import (
TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr,
ARG_POS, ARG_STAR, ARG_STAR2,
Expand Down Expand Up @@ -120,9 +120,12 @@ def analyze_member_access(name: str,
original_type=original_type, chk=chk)
elif isinstance(typ, TypedDictType):
# Actually look up from the fallback instance type.
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
is_operator, builtin_type, not_ready_callback, msg,
original_type=original_type, chk=chk)
result = analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
is_operator, builtin_type, not_ready_callback, msg,
original_type=original_type, chk=chk)
if name == 'get' and isinstance(result, CallableType):
result = TypedDictGetFunction(typ, result)
return result
elif isinstance(typ, FunctionLike) and typ.is_type_obj():
# Class attribute.
# TODO super?
Expand Down
20 changes: 20 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,26 @@ def zipall(self, right: 'TypedDictType') \
yield (item_name, None, right_item_type)


class TypedDictGetFunction(CallableType):
"""A special callable type containing a reference to the TypedDict `get` callable instance.
This is needed to delay determining the signature of a TypedDict's `get` method until the
method is actually called. This allows `get` to behave just as indexing into the TypedDict
would.

This is not a real type, but is needed to allow TypedDict.get to behave as expected.
"""
def __init__(self, typed_dict: TypedDictType, fallback_callable: CallableType) -> None:
super().__init__(fallback_callable.arg_types, fallback_callable.arg_kinds,
fallback_callable.arg_names, fallback_callable.ret_type,
fallback_callable.fallback, fallback_callable.name,
fallback_callable.definition, fallback_callable.variables,
fallback_callable.line, fallback_callable.column,
fallback_callable.is_ellipsis_args, fallback_callable.implicit,
fallback_callable.is_classmethod_class, fallback_callable.special_sig)
self.typed_dict = typed_dict
self.fallback_callable = fallback_callable


class StarType(Type):
"""The star type *type_parameter.

Expand Down
84 changes: 84 additions & 0 deletions test-data/unit/check-typeddict.test
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,90 @@ def set_coordinate(p: TaggedPoint, key: str, value: int) -> None:

-- Special Method: get

[case testCanUseGetMethodWithStringLiteralKey]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
p = TaggedPoint(type='2d', x=42, y=1337)
reveal_type(p.get('type')) # E: Revealed type is 'Union[builtins.str, builtins.None]'
reveal_type(p.get('x')) # E: Revealed type is 'Union[builtins.int, builtins.None]'
reveal_type(p.get('y', 0)) # E: Revealed type is 'builtins.int'
[builtins fixtures/dict.pyi]

[case testDefaultParameterStillTypeChecked]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
p = TaggedPoint(type='2d', x=42, y=1337)
p.get('x', 1 + 'y') # E: Unsupported operand types for + ("int" and "str")
[builtins fixtures/dict.pyi]

[case testCannotGetMethodWithInvalidStringLiteralKey]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
p = TaggedPoint(type='2d', x=42, y=1337)
p.get('z') # E: 'z' is not a valid item name; expected one of ['type', 'x', 'y']
[builtins fixtures/dict.pyi]

[case testGetMethodWithVariableKeyFallsBack]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
p = TaggedPoint(type='2d', x=42, y=1337)
key = 'type'
reveal_type(p.get(key)) # E: Revealed type is 'builtins.object*'
[builtins fixtures/dict.pyi]

[case testChainedGetMethodWithDictFallback]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
PointSet = TypedDict('PointSet', {'first_point': TaggedPoint})
p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337))
reveal_type(p.get('first_point', {}).get('x', 0)) # E: Revealed type is 'builtins.int'
[builtins fixtures/dict.pyi]

[case testGetMethodInvalidDefaultType]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
PointSet = TypedDict('PointSet', {'first_point': TaggedPoint})
p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337))
p.get('first_point', 32) # E: Argument 2 to "get" of "Mapping" has incompatible type "int"; expected "Union[TypedDict(type=str, x=int, y=int), Mapping]"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pretty minor thing, but the error message is a little confusing: it says "get" of "Mapping" even though the signature of get in this particular case comes from PointSet. A better message would be ... "get" of "PointSet" ... or `... "get" of a TypedDict ..." (if the name of the typed dict is not available).

(If this seems hard to do, we can create as a separate issue for this.)

[builtins fixtures/dict.pyi]

[case testGetMethodOnList]
from typing import List
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
PointSet = TypedDict('PointSet', {'points': List[TaggedPoint]})
p = PointSet(points=[TaggedPoint(type='2d', x=42, y=1337)])
reveal_type(p.get('points', [])) # E: Revealed type is 'builtins.list[TypedDict(type=builtins.str, x=builtins.int, y=builtins.int, _fallback=__main__.TaggedPoint)]'
[builtins fixtures/dict.pyi]

[case testGetMethodWithListOfStrUnifies]
from typing import List
from mypy_extensions import TypedDict
Items = TypedDict('Items', {'name': str, 'values': List[str]})
def foo(i: Items) -> None:
reveal_type(i.get('values', [])) # E: Revealed type is 'builtins.list[builtins.str]'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style nit: using 2 spaces for indent here.

[builtins fixtures/dict.pyi]

[case testDictGetMethodStillCallable]
from typing import Callable
from mypy_extensions import TypedDict
Point = TypedDict('Point', {'x': int, 'y': int})
p = Point(x=42, y=13)
def invoke_method(method: Callable[[str, int], int]) -> None:
pass
invoke_method(p.get)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a test case where p.get does not have a compatible type (e.g. target function expects Callable[[int, str], int] or something).

[builtins fixtures/dict.pyi]

[case testDictGetMethodStillCallableWithObject]
from typing import Callable
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
p = TaggedPoint(type='2d', x=42, y=1337)
def invoke_method(method: Callable[..., object]) -> None:
pass
invoke_method(p.get)
[builtins fixtures/dict.pyi]

-- TODO: Implement support for these cases:
--[case testGetOfTypedDictWithValidStringLiteralKeyReturnsPreciseType]
--[case testGetOfTypedDictWithInvalidStringLiteralKeyIsError]
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/dict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class dict(Iterable[KT], Mapping[KT, VT], Generic[KT, VT]):
def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass
def __setitem__(self, k: KT, v: VT) -> None: pass
def __iter__(self) -> Iterator[KT]: pass
def get(self, k: KT, default: VT=None) -> VT: pass
def update(self, a: Mapping[KT, VT]) -> None: pass

class int: # for convenience
Expand Down
4 changes: 3 additions & 1 deletion test-data/unit/lib-stub/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ class Sequence(Iterable[T], Generic[T]):
@abstractmethod
def __getitem__(self, n: Any) -> T: pass

class Mapping(Generic[T, U]): pass
class Mapping(Generic[T, U]):
@abstractmethod
def get(self, k: T, default: U=None) -> U: pass

class MutableMapping(Generic[T, U]): pass

Expand Down