Skip to content

Commit

Permalink
gh-119180: Add evaluate functions for type params and type aliases (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jul 27, 2024
1 parent cbac8a3 commit ae19226
Show file tree
Hide file tree
Showing 11 changed files with 385 additions and 159 deletions.
1 change: 1 addition & 0 deletions Include/internal/pycore_global_objects.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ struct _Py_interp_cached_objects {
PyTypeObject *paramspec_type;
PyTypeObject *paramspecargs_type;
PyTypeObject *paramspeckwargs_type;
PyTypeObject *constevaluator_type;
};

#define _Py_INTERP_STATIC_OBJECT(interp, NAME) \
Expand Down
1 change: 1 addition & 0 deletions Include/internal/pycore_typevarobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ extern PyObject *_Py_subscript_generic(PyThreadState *, PyObject *);
extern PyObject *_Py_set_typeparam_default(PyThreadState *, PyObject *, PyObject *);
extern int _Py_initialize_generic(PyInterpreterState *);
extern void _Py_clear_generic_types(PyInterpreterState *);
extern int _Py_typing_type_repr(PyUnicodeWriter *, PyObject *);

extern PyTypeObject _PyTypeAlias_Type;
extern PyTypeObject _PyNoDefault_Type;
Expand Down
19 changes: 16 additions & 3 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,16 @@ def __missing__(self, key):
return fwdref


def call_annotate_function(annotate, format, owner=None):
def call_evaluate_function(evaluate, format, *, owner=None):
"""Call an evaluate function. Evaluate functions are normally generated for
the value of type aliases and the bounds, constraints, and defaults of
type parameter objects.
"""
return call_annotate_function(evaluate, format, owner=owner, _is_evaluate=True)


def call_annotate_function(annotate, format, *, owner=None,
_is_evaluate=False):
"""Call an __annotate__ function. __annotate__ functions are normally
generated by the compiler to defer the evaluation of annotations. They
can be called with any of the format arguments in the Format enum, but
Expand Down Expand Up @@ -459,8 +468,11 @@ def call_annotate_function(annotate, format, owner=None):
closure = tuple(new_closure)
else:
closure = None
func = types.FunctionType(annotate.__code__, globals, closure=closure)
func = types.FunctionType(annotate.__code__, globals, closure=closure,
argdefs=annotate.__defaults__, kwdefaults=annotate.__kwdefaults__)
annos = func(Format.VALUE)
if _is_evaluate:
return annos if isinstance(annos, str) else repr(annos)
return {
key: val if isinstance(val, str) else repr(val)
for key, val in annos.items()
Expand Down Expand Up @@ -511,7 +523,8 @@ def call_annotate_function(annotate, format, owner=None):
closure = tuple(new_closure)
else:
closure = None
func = types.FunctionType(annotate.__code__, globals, closure=closure)
func = types.FunctionType(annotate.__code__, globals, closure=closure,
argdefs=annotate.__defaults__, kwdefaults=annotate.__kwdefaults__)
result = func(Format.VALUE)
for obj in globals.stringifiers:
obj.__class__ = ForwardRef
Expand Down
19 changes: 19 additions & 0 deletions Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,25 @@ def test_pep_695_generics_with_future_annotations_nested_in_function(self):
)


class TestCallEvaluateFunction(unittest.TestCase):
def test_evaluation(self):
def evaluate(format, exc=NotImplementedError):
if format != 1:
raise exc
return undefined

with self.assertRaises(NameError):
annotationlib.call_evaluate_function(evaluate, annotationlib.Format.VALUE)
self.assertEqual(
annotationlib.call_evaluate_function(evaluate, annotationlib.Format.FORWARDREF),
annotationlib.ForwardRef("undefined"),
)
self.assertEqual(
annotationlib.call_evaluate_function(evaluate, annotationlib.Format.SOURCE),
"undefined",
)


class MetaclassTests(unittest.TestCase):
def test_annotated_meta(self):
class Meta(type):
Expand Down
43 changes: 42 additions & 1 deletion Lib/test/test_type_params.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import annotationlib
import asyncio
import textwrap
import types
Expand All @@ -6,7 +7,7 @@
import weakref
from test.support import requires_working_socket, check_syntax_error, run_code

from typing import Generic, NoDefault, Sequence, TypeVar, TypeVarTuple, ParamSpec, get_args
from typing import Generic, NoDefault, Sequence, TypeAliasType, TypeVar, TypeVarTuple, ParamSpec, get_args


class TypeParamsInvalidTest(unittest.TestCase):
Expand Down Expand Up @@ -1394,3 +1395,43 @@ def test_symtable_key_regression_name(self):

self.assertEqual(ns["X1"].__type_params__[0].__default__, "A")
self.assertEqual(ns["X2"].__type_params__[0].__default__, "B")


class TestEvaluateFunctions(unittest.TestCase):
def test_general(self):
type Alias = int
Alias2 = TypeAliasType("Alias2", int)
def f[T: int = int, **P = int, *Ts = int](): pass
T, P, Ts = f.__type_params__
T2 = TypeVar("T2", bound=int, default=int)
P2 = ParamSpec("P2", default=int)
Ts2 = TypeVarTuple("Ts2", default=int)
cases = [
Alias.evaluate_value,
Alias2.evaluate_value,
T.evaluate_bound,
T.evaluate_default,
P.evaluate_default,
Ts.evaluate_default,
T2.evaluate_bound,
T2.evaluate_default,
P2.evaluate_default,
Ts2.evaluate_default,
]
for case in cases:
with self.subTest(case=case):
self.assertIs(case(1), int)
self.assertIs(annotationlib.call_evaluate_function(case, annotationlib.Format.VALUE), int)
self.assertIs(annotationlib.call_evaluate_function(case, annotationlib.Format.FORWARDREF), int)
self.assertEqual(annotationlib.call_evaluate_function(case, annotationlib.Format.SOURCE), 'int')

def test_constraints(self):
def f[T: (int, str)](): pass
T, = f.__type_params__
T2 = TypeVar("T2", int, str)
for case in [T, T2]:
with self.subTest(case=case):
self.assertEqual(case.evaluate_constraints(1), (int, str))
self.assertEqual(annotationlib.call_evaluate_function(case.evaluate_constraints, annotationlib.Format.VALUE), (int, str))
self.assertEqual(annotationlib.call_evaluate_function(case.evaluate_constraints, annotationlib.Format.FORWARDREF), (int, str))
self.assertEqual(annotationlib.call_evaluate_function(case.evaluate_constraints, annotationlib.Format.SOURCE), '(int, str)')
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
As part of :pep:`749`, add the following attributes for customizing
evaluation of annotation scopes:

* ``evaluate_value`` on :class:`typing.TypeAliasType`
* ``evaluate_bound``, ``evaluate_constraints``, and ``evaluate_default`` on :class:`typing.TypeVar`
* ``evaluate_default`` on :class:`typing.ParamSpec`
* ``evaluate_default`` on :class:`typing.TypeVarTuple`
70 changes: 4 additions & 66 deletions Objects/genericaliasobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
#include "pycore_modsupport.h" // _PyArg_NoKeywords()
#include "pycore_object.h"
#include "pycore_typevarobject.h" // _Py_typing_type_repr
#include "pycore_unionobject.h" // _Py_union_type_or, _PyGenericAlias_Check


Expand Down Expand Up @@ -50,69 +51,6 @@ ga_traverse(PyObject *self, visitproc visit, void *arg)
return 0;
}

static int
ga_repr_item(PyUnicodeWriter *writer, PyObject *p)
{
PyObject *qualname = NULL;
PyObject *module = NULL;
int rc;

if (p == Py_Ellipsis) {
// The Ellipsis object
rc = PyUnicodeWriter_WriteUTF8(writer, "...", 3);
goto done;
}

if ((rc = PyObject_HasAttrWithError(p, &_Py_ID(__origin__))) > 0 &&
(rc = PyObject_HasAttrWithError(p, &_Py_ID(__args__))) > 0)
{
// It looks like a GenericAlias
goto use_repr;
}
if (rc < 0) {
goto error;
}

if (PyObject_GetOptionalAttr(p, &_Py_ID(__qualname__), &qualname) < 0) {
goto error;
}
if (qualname == NULL) {
goto use_repr;
}
if (PyObject_GetOptionalAttr(p, &_Py_ID(__module__), &module) < 0) {
goto error;
}
if (module == NULL || module == Py_None) {
goto use_repr;
}

// Looks like a class
if (PyUnicode_Check(module) &&
_PyUnicode_EqualToASCIIString(module, "builtins"))
{
// builtins don't need a module name
rc = PyUnicodeWriter_WriteStr(writer, qualname);
goto done;
}
else {
rc = PyUnicodeWriter_Format(writer, "%S.%S", module, qualname);
goto done;
}

error:
rc = -1;
goto done;

use_repr:
rc = PyUnicodeWriter_WriteRepr(writer, p);
goto done;

done:
Py_XDECREF(qualname);
Py_XDECREF(module);
return rc;
}

static int
ga_repr_items_list(PyUnicodeWriter *writer, PyObject *p)
{
Expand All @@ -131,7 +69,7 @@ ga_repr_items_list(PyUnicodeWriter *writer, PyObject *p)
}
}
PyObject *item = PyList_GET_ITEM(p, i);
if (ga_repr_item(writer, item) < 0) {
if (_Py_typing_type_repr(writer, item) < 0) {
return -1;
}
}
Expand Down Expand Up @@ -162,7 +100,7 @@ ga_repr(PyObject *self)
goto error;
}
}
if (ga_repr_item(writer, alias->origin) < 0) {
if (_Py_typing_type_repr(writer, alias->origin) < 0) {
goto error;
}
if (PyUnicodeWriter_WriteChar(writer, '[') < 0) {
Expand All @@ -181,7 +119,7 @@ ga_repr(PyObject *self)
goto error;
}
}
else if (ga_repr_item(writer, p) < 0) {
else if (_Py_typing_type_repr(writer, p) < 0) {
goto error;
}
}
Expand Down
Loading

0 comments on commit ae19226

Please sign in to comment.