diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 833664f46b24d..fb23c35900884 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -347,6 +347,24 @@ def check_call(self, callee: Type, args: List[Expression], callee.type_object().name(), type.abstract_attributes, context) + if isinstance(callee, TypedDictGetFunction): + if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)): + return_type = self.get_typeddict_index_type(callee.typed_dict, args[0]) + arg_types = callee.arg_types + if len(args) == 1: + return_type = UnionType.make_union([ + return_type, NoneTyp()]) + elif isinstance(return_type, TypedDictType) and len(callee.arg_types) == 2: + # Explicitly set the type of the default parameter to + # Union[typing.Mapping, ] in cases where the return value + # is a typed dict. This special case allows for chaining of `get` methods + # when accessing elements deep within nested dictionaries in a safe and + # concise way without having to set up exception handlers. + arg_types = [callee.arg_types[0], + UnionType.make_union([return_type, + self.named_type('typing.Mapping')])] + callee = callee.copy_modified(ret_type=return_type, arg_types=arg_types) + formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, callee.arg_kinds, callee.arg_names, @@ -362,24 +380,6 @@ def check_call(self, callee: Type, args: List[Expression], arg_types = self.infer_arg_types_in_context2( callee, args, arg_kinds, formal_to_actual) - if isinstance(callee, TypedDictGetFunction): - if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)): - return_type = self.get_typeddict_index_type(callee.typed_dict, args[0]) - if len(args) == 1: - return_type = UnionType.make_union([ - return_type, NoneTyp()]) - else: - # Explicitly set the return type to be a the TypedDict in cases where the - # call site is of the form `x.get('key', {})` and x['key'] is another - # TypedDict. This special case allows for chaining of `get` methods when - # accessing elements deep within nested dictionaries in a safe and - # concise way without having to set up exception handlers. - if not (isinstance(return_type, TypedDictType) and - is_subtype(arg_types[1], self.named_type('typing.Mapping'))): - return_type = UnionType.make_simplified_union( - [return_type, arg_types[1]]) - return return_type, callee - self.check_argument_count(callee, arg_types, arg_kinds, arg_names, formal_to_actual, context, self.msg) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 4e3745913ae4d..0cff748afa7f9 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -438,8 +438,6 @@ p = TaggedPoint(type='2d', x=42, y=1337) reveal_type(p.get('type')) # E: Revealed type is 'Union[builtins.str, builtins.None]' reveal_type(p.get('x')) # E: Revealed type is 'Union[builtins.int, builtins.None]' reveal_type(p.get('y', 0)) # E: Revealed type is 'builtins.int' -reveal_type(p.get('y', 'hello')) # E: Revealed type is 'Union[builtins.int, builtins.str]' -reveal_type(p.get('y', {})) # E: Revealed type is 'Union[builtins.int, builtins.dict[builtins.None, builtins.None]]' [builtins fixtures/dict.pyi] [case testDefaultParameterStillTypeChecked] @@ -472,12 +470,12 @@ p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337)) reveal_type(p.get('first_point', {}).get('x', 0)) # E: Revealed type is 'builtins.int' [builtins fixtures/dict.pyi] -[case testChainedGetMethodWithNonDictFallback] +[case testGetMethodInvalidDefaultType] from mypy_extensions import TypedDict TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int}) PointSet = TypedDict('PointSet', {'first_point': TaggedPoint}) p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337)) -p.get('first_point', 32).get('x', 0) # E: Some element of union has no attribute "get" +p.get('first_point', 32) # E: Argument 2 to "get" of "Mapping" has incompatible type "int"; expected "Union[TaggedPoint, Mapping]" [builtins fixtures/dict.pyi] [case testGetMethodOnList] @@ -489,6 +487,14 @@ p = PointSet(points=[TaggedPoint(type='2d', x=42, y=1337)]) reveal_type(p.get('points', [])) # E: Revealed type is 'builtins.list[TypedDict(type=builtins.str, x=builtins.int, y=builtins.int, _fallback=__main__.TaggedPoint)]' [builtins fixtures/dict.pyi] +[case testGetMethodWithListOfStrUnifies] +from typing import List +from mypy_extensions import TypedDict +Items = TypedDict('Items', {'name': str, 'values': List[str]}) +def foo(i: Items) -> None: + reveal_type(i.get('values', [])) # E: Revealed type is 'builtins.list[builtins.str]' +[builtins fixtures/dict.pyi] + [case testDictGetMethodStillCallable] from typing import Callable from mypy_extensions import TypedDict