diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 28cbbb6a861b..a0a4fa90e122 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1848,13 +1848,14 @@ def visit_dict_expr(self, e: DictExpr) -> Type: # an error, but returns the TypedDict type that matches the literal it found # that would cause a second error when that TypedDict type is returned upstream # to avoid the second error, we always return TypedDict type that was requested - if isinstance(self.type_context[-1], TypedDictType): + typeddict_context = self.find_typeddict_context(self.type_context[-1]) + if typeddict_context: self.check_typeddict_call_with_dict( - callee=self.type_context[-1], + callee=typeddict_context, kwargs=e, context=e ) - return self.type_context[-1].copy_modified() + return typeddict_context.copy_modified() # Collect function arguments, watching out for **expr. args = [] # type: List[Expression] # Regular "key: value" @@ -1905,6 +1906,19 @@ def visit_dict_expr(self, e: DictExpr) -> Type: self.check_call(method, [arg], [nodes.ARG_POS], arg) return rv + def find_typeddict_context(self, context: Type) -> Optional[TypedDictType]: + if isinstance(context, TypedDictType): + return context + elif isinstance(context, UnionType): + items = [] + for item in context.items: + item_context = self.find_typeddict_context(item) + if item_context: + items.append(item_context) + if len(items) == 1: + return items[0] + return None + def visit_lambda_expr(self, e: LambdaExpr) -> Type: """Type check lambda expression.""" inferred_type, type_override = self.infer_lambda_type_using_context(e) diff --git a/mypy/plugin.py b/mypy/plugin.py index c8fe39910d53..adc4074b4ce3 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -224,9 +224,6 @@ def typed_dict_get_callback( else: context.msg.typeddict_item_name_not_found(object_type, key, context.context) return AnyType() - else: - context.msg.typeddict_item_name_must_be_string_literal(object_type, context.context) - return AnyType() return inferred_return_type diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 568c1ff95d96..c29aad48af92 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -784,8 +784,8 @@ d.get('x', 1, 2) # E: No overload variant of "get" of "Mapping" matches argument x = d.get('z') # E: 'z' is not a valid item name; expected one of ['x', 'y'] reveal_type(x) # E: Revealed type is 'Any' s = '' -y = d.get(s) # E: Cannot prove expression is a valid item name; expected one of ['x', 'y'] -reveal_type(y) # E: Revealed type is 'Any' +y = d.get(s) +reveal_type(y) # E: Revealed type is 'builtins.object*' [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] @@ -795,3 +795,21 @@ D = TypedDict('D', {'x': int, 'y': str}) d: D d.bad(1) # E: "D" has no attribute "bad" [builtins fixtures/dict.pyi] + +[case testTypedDictChainedGetMethodWithDictFallback] +from mypy_extensions import TypedDict +D = TypedDict('D', {'x': int, 'y': str}) +E = TypedDict('E', {'d': D}) +p = E(d=D(x=0, y='')) +reveal_type(p.get('d', {'x': 1, 'y': ''})) # E: Revealed type is 'TypedDict(x=builtins.int, y=builtins.str, _fallback=__main__.D)' +p.get('d', {}) # E: Expected items ['x', 'y'] but found []. +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + +[case testTypedDictGetDefaultParameterStillTypeChecked] +from mypy_extensions import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int}) +p = TaggedPoint(type='2d', x=42, y=1337) +p.get('x', 1 + 'y') # E: Unsupported operand types for + ("int" and "str") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi]