diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 0a67742a2b3081..be3bc275817f50 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -1,8 +1,10 @@ """Helpers for introspecting and wrapping annotations.""" import ast +import builtins import enum import functools +import keyword import sys import types @@ -154,8 +156,19 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): globals[param_name] = param locals.pop(param_name, None) - code = self.__forward_code__ - value = eval(code, globals=globals, locals=locals) + arg = self.__forward_arg__ + if arg.isidentifier() and not keyword.iskeyword(arg): + if arg in locals: + value = locals[arg] + elif arg in globals: + value = globals[arg] + elif hasattr(builtins, arg): + return getattr(builtins, arg) + else: + raise NameError(arg) + else: + code = self.__forward_code__ + value = eval(code, globals=globals, locals=locals) self.__forward_evaluated__ = True self.__forward_value__ = value return value @@ -254,7 +267,9 @@ class _Stringifier: __slots__ = _SLOTS def __init__(self, node, globals=None, owner=None, is_class=False, cell=None): - assert isinstance(node, ast.AST) + # Either an AST node or a simple str (for the common case where a ForwardRef + # represent a single name). + assert isinstance(node, (ast.AST, str)) self.__arg__ = None self.__forward_evaluated__ = False self.__forward_value__ = None @@ -267,18 +282,26 @@ def __init__(self, node, globals=None, owner=None, is_class=False, cell=None): self.__cell__ = cell self.__owner__ = owner - def __convert(self, other): + def __convert_to_ast(self, other): if isinstance(other, _Stringifier): + if isinstance(other.__ast_node__, str): + return ast.Name(id=other.__ast_node__) return other.__ast_node__ elif isinstance(other, slice): return ast.Slice( - lower=self.__convert(other.start) if other.start is not None else None, - upper=self.__convert(other.stop) if other.stop is not None else None, - step=self.__convert(other.step) if other.step is not None else None, + lower=self.__convert_to_ast(other.start) if other.start is not None else None, + upper=self.__convert_to_ast(other.stop) if other.stop is not None else None, + step=self.__convert_to_ast(other.step) if other.step is not None else None, ) else: return ast.Constant(value=other) + def __get_ast(self): + node = self.__ast_node__ + if isinstance(node, str): + return ast.Name(id=node) + return node + def __make_new(self, node): return _Stringifier( node, self.__globals__, self.__owner__, self.__forward_is_class__ @@ -292,38 +315,37 @@ def __hash__(self): def __getitem__(self, other): # Special case, to avoid stringifying references to class-scoped variables # as '__classdict__["x"]'. - if ( - isinstance(self.__ast_node__, ast.Name) - and self.__ast_node__.id == "__classdict__" - ): + if self.__ast_node__ == "__classdict__": raise KeyError if isinstance(other, tuple): - elts = [self.__convert(elt) for elt in other] + elts = [self.__convert_to_ast(elt) for elt in other] other = ast.Tuple(elts) else: - other = self.__convert(other) + other = self.__convert_to_ast(other) assert isinstance(other, ast.AST), repr(other) - return self.__make_new(ast.Subscript(self.__ast_node__, other)) + return self.__make_new(ast.Subscript(self.__get_ast(), other)) def __getattr__(self, attr): - return self.__make_new(ast.Attribute(self.__ast_node__, attr)) + return self.__make_new(ast.Attribute(self.__get_ast(), attr)) def __call__(self, *args, **kwargs): return self.__make_new( ast.Call( - self.__ast_node__, - [self.__convert(arg) for arg in args], + self.__get_ast(), + [self.__convert_to_ast(arg) for arg in args], [ - ast.keyword(key, self.__convert(value)) + ast.keyword(key, self.__convert_to_ast(value)) for key, value in kwargs.items() ], ) ) def __iter__(self): - yield self.__make_new(ast.Starred(self.__ast_node__)) + yield self.__make_new(ast.Starred(self.__get_ast())) def __repr__(self): + if isinstance(self.__ast_node__, str): + return self.__ast_node__ return ast.unparse(self.__ast_node__) def __format__(self, format_spec): @@ -332,7 +354,7 @@ def __format__(self, format_spec): def _make_binop(op: ast.AST): def binop(self, other): return self.__make_new( - ast.BinOp(self.__ast_node__, op, self.__convert(other)) + ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other)) ) return binop @@ -356,7 +378,7 @@ def binop(self, other): def _make_rbinop(op: ast.AST): def rbinop(self, other): return self.__make_new( - ast.BinOp(self.__convert(other), op, self.__ast_node__) + ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast()) ) return rbinop @@ -381,9 +403,9 @@ def _make_compare(op): def compare(self, other): return self.__make_new( ast.Compare( - left=self.__ast_node__, + left=self.__get_ast(), ops=[op], - comparators=[self.__convert(other)], + comparators=[self.__convert_to_ast(other)], ) ) @@ -400,7 +422,7 @@ def compare(self, other): def _make_unary_op(op): def unary_op(self): - return self.__make_new(ast.UnaryOp(op, self.__ast_node__)) + return self.__make_new(ast.UnaryOp(op, self.__get_ast())) return unary_op @@ -422,7 +444,7 @@ def __init__(self, namespace, globals=None, owner=None, is_class=False): def __missing__(self, key): fwdref = _Stringifier( - ast.Name(id=key), + key, globals=self.globals, owner=self.owner, is_class=self.is_class, @@ -480,7 +502,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): name = freevars[i] else: name = "__cell__" - fwdref = _Stringifier(ast.Name(id=name)) + fwdref = _Stringifier(name) new_closure.append(types.CellType(fwdref)) closure = tuple(new_closure) else: @@ -532,7 +554,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): else: name = "__cell__" fwdref = _Stringifier( - ast.Name(id=name), + name, cell=cell, owner=owner, globals=annotate.__globals__, @@ -555,6 +577,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): result = func(Format.VALUE) for obj in globals.stringifiers: obj.__class__ = ForwardRef + if isinstance(obj.__ast_node__, str): + obj.__arg__ = obj.__ast_node__ + obj.__ast_node__ = None return result elif format == Format.VALUE: # Should be impossible because __annotate__ functions must not raise diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index dd8ceb55a411fb..cc051ef3b93658 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -1,6 +1,7 @@ """Tests for the annotations module.""" import annotationlib +import builtins import collections import functools import itertools @@ -280,7 +281,14 @@ class Gen[T]: def test_fwdref_with_module(self): self.assertIs(ForwardRef("Format", module="annotationlib").evaluate(), Format) - self.assertIs(ForwardRef("Counter", module="collections").evaluate(), collections.Counter) + self.assertIs( + ForwardRef("Counter", module="collections").evaluate(), + collections.Counter + ) + self.assertEqual( + ForwardRef("Counter[int]", module="collections").evaluate(), + collections.Counter[int], + ) with self.assertRaises(NameError): # If globals are passed explicitly, we don't look at the module dict @@ -305,6 +313,33 @@ def test_fwdref_value_is_cached(self): self.assertIs(fr.evaluate(globals={"hello": str}), str) self.assertIs(fr.evaluate(), str) + def test_fwdref_with_owner(self): + self.assertEqual( + ForwardRef("Counter[int]", owner=collections).evaluate(), + collections.Counter[int], + ) + + def test_name_lookup_without_eval(self): + # test the codepath where we look up simple names directly in the + # namespaces without going through eval() + self.assertIs(ForwardRef("int").evaluate(), int) + self.assertIs(ForwardRef("int").evaluate(locals={"int": str}), str) + self.assertIs(ForwardRef("int").evaluate(locals={"int": float}, globals={"int": str}), float) + self.assertIs(ForwardRef("int").evaluate(globals={"int": str}), str) + with support.swap_attr(builtins, "int", dict): + self.assertIs(ForwardRef("int").evaluate(), dict) + + with self.assertRaises(NameError): + ForwardRef("doesntexist").evaluate() + + def test_fwdref_invalid_syntax(self): + fr = ForwardRef("if") + with self.assertRaises(SyntaxError): + fr.evaluate() + fr = ForwardRef("1+") + with self.assertRaises(SyntaxError): + fr.evaluate() + class TestGetAnnotations(unittest.TestCase): def test_builtin_type(self):