Skip to content

Commit

Permalink
Simplify logic with overriding callee in cases of TypedDict.get funct…
Browse files Browse the repository at this point in the history
…ions.

After poking around with this a bunch today I realized it would be much simplier to simply create
a context-specific Callable as opposed to attemping to hijack the rest of the typechecking.

The original implementation had problems in places, for example where a TypedDict had a List field.
A default empty list was not being coerced correctly.
  • Loading branch information
Roy Williams committed Jan 12, 2017
1 parent c5f7481 commit 8c5975e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 22 deletions.
36 changes: 18 additions & 18 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,24 @@ def check_call(self, callee: Type, args: List[Expression],
callee.type_object().name(), type.abstract_attributes,
context)

if isinstance(callee, TypedDictGetFunction):
if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)):
return_type = self.get_typeddict_index_type(callee.typed_dict, args[0])
arg_types = callee.arg_types
if len(args) == 1:
return_type = UnionType.make_union([
return_type, NoneTyp()])
elif isinstance(return_type, TypedDictType) and len(callee.arg_types) == 2:
# Explicitly set the type of the default parameter to
# Union[typing.Mapping, <return type>] in cases where the return value
# is a typed dict. This special case allows for chaining of `get` methods
# when accessing elements deep within nested dictionaries in a safe and
# concise way without having to set up exception handlers.
arg_types = [callee.arg_types[0],
UnionType.make_union([return_type,
self.named_type('typing.Mapping')])]
callee = callee.copy_modified(ret_type=return_type, arg_types=arg_types)

formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names,
callee.arg_kinds, callee.arg_names,
Expand All @@ -362,24 +380,6 @@ def check_call(self, callee: Type, args: List[Expression],
arg_types = self.infer_arg_types_in_context2(
callee, args, arg_kinds, formal_to_actual)

if isinstance(callee, TypedDictGetFunction):
if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)):
return_type = self.get_typeddict_index_type(callee.typed_dict, args[0])
if len(args) == 1:
return_type = UnionType.make_union([
return_type, NoneTyp()])
else:
# Explicitly set the return type to be a the TypedDict in cases where the
# call site is of the form `x.get('key', {})` and x['key'] is another
# TypedDict. This special case allows for chaining of `get` methods when
# accessing elements deep within nested dictionaries in a safe and
# concise way without having to set up exception handlers.
if not (isinstance(return_type, TypedDictType) and
is_subtype(arg_types[1], self.named_type('typing.Mapping'))):
return_type = UnionType.make_simplified_union(
[return_type, arg_types[1]])
return return_type, callee

self.check_argument_count(callee, arg_types, arg_kinds,
arg_names, formal_to_actual, context, self.msg)

Expand Down
14 changes: 10 additions & 4 deletions test-data/unit/check-typeddict.test
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,6 @@ p = TaggedPoint(type='2d', x=42, y=1337)
reveal_type(p.get('type')) # E: Revealed type is 'Union[builtins.str, builtins.None]'
reveal_type(p.get('x')) # E: Revealed type is 'Union[builtins.int, builtins.None]'
reveal_type(p.get('y', 0)) # E: Revealed type is 'builtins.int'
reveal_type(p.get('y', 'hello')) # E: Revealed type is 'Union[builtins.int, builtins.str]'
reveal_type(p.get('y', {})) # E: Revealed type is 'Union[builtins.int, builtins.dict[builtins.None, builtins.None]]'
[builtins fixtures/dict.pyi]

[case testDefaultParameterStillTypeChecked]
Expand Down Expand Up @@ -472,12 +470,12 @@ p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337))
reveal_type(p.get('first_point', {}).get('x', 0)) # E: Revealed type is 'builtins.int'
[builtins fixtures/dict.pyi]

[case testChainedGetMethodWithNonDictFallback]
[case testGetMethodInvalidDefaultType]
from mypy_extensions import TypedDict
TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int})
PointSet = TypedDict('PointSet', {'first_point': TaggedPoint})
p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337))
p.get('first_point', 32).get('x', 0) # E: Some element of union has no attribute "get"
p.get('first_point', 32) # E: Argument 2 to "get" of "Mapping" has incompatible type "int"; expected "Union[TaggedPoint, Mapping]"
[builtins fixtures/dict.pyi]

[case testGetMethodOnList]
Expand All @@ -489,6 +487,14 @@ p = PointSet(points=[TaggedPoint(type='2d', x=42, y=1337)])
reveal_type(p.get('points', [])) # E: Revealed type is 'builtins.list[TypedDict(type=builtins.str, x=builtins.int, y=builtins.int, _fallback=__main__.TaggedPoint)]'
[builtins fixtures/dict.pyi]

[case testGetMethodWithListOfStrUnifies]
from typing import List
from mypy_extensions import TypedDict
Items = TypedDict('Items', {'name': str, 'values': List[str]})
def foo(i: Items) -> None:
reveal_type(i.get('values', [])) # E: Revealed type is 'builtins.list[builtins.str]'
[builtins fixtures/dict.pyi]

[case testDictGetMethodStillCallable]
from typing import Callable
from mypy_extensions import TypedDict
Expand Down

0 comments on commit 8c5975e

Please sign in to comment.